diff --git a/rama/Cargo.toml b/rama/Cargo.toml index 2c792997..4ee6dcf7 100644 --- a/rama/Cargo.toml +++ b/rama/Cargo.toml @@ -16,12 +16,10 @@ matchit = "0.7.3" pin-project-lite = "0.2.13" tokio = { version = "1.33.0", features = ["net", "io-util"] } tokio-graceful = "0.1.5" -tower-async-layer = "0.1.1" -tower-async-service = "0.1.1" +tower-async = { version = "0.1.1", features = ["util"] } +tracing = "0.1.40" [dev-dependencies] tokio = { version = "1.33.0", features = ["full"] } tokio-test = "0.4.3" -tower-async = { version = "0.1.1", features = ["full"] } -tracing = "0.1.40" tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } diff --git a/rama/examples/tokio_tcp_echo_server.rs b/rama/examples/tokio_tcp_echo_server.rs index aa803fac..5d1de0d4 100644 --- a/rama/examples/tokio_tcp_echo_server.rs +++ b/rama/examples/tokio_tcp_echo_server.rs @@ -2,10 +2,9 @@ use std::time::Duration; use rama::graceful::{Shutdown, ShutdownGuardAdderLayer}; use rama::server::tcp::TcpListener; +use rama::service::spawn::SpawnLayer; use rama::stream::service::EchoService; -use rama::Service; -use tower_async::{service_fn, ServiceBuilder}; use tracing::metadata::LevelFilter; use tracing_subscriber::{fmt, prelude::*, EnvFilter}; @@ -23,22 +22,12 @@ async fn main() { let shutdown = Shutdown::default(); shutdown.spawn_task_fn(|guard| async { - let guard = guard.downgrade(); TcpListener::bind("127.0.0.1:8080") .await .expect("bind TCP Listener") - .serve(service_fn(|stream| async { - let guard = guard.clone(); - tokio::spawn(async move { - ServiceBuilder::new() - .layer(ShutdownGuardAdderLayer::new(guard)) - .service(EchoService::new()) - .call(stream) - .await - .expect("call EchoService"); - }); - Ok::<(), std::convert::Infallible>(()) - })) + .layer(ShutdownGuardAdderLayer::new(guard.downgrade())) + .layer(SpawnLayer::new()) + .serve::<_, EchoService, _>(EchoService::new()) .await .expect("serve incoming TCP connections"); }); diff --git a/rama/src/graceful.rs b/rama/src/graceful.rs index 2220ed2e..2387f558 100644 --- a/rama/src/graceful.rs +++ b/rama/src/graceful.rs @@ -1,14 +1,17 @@ -pub use tokio_graceful::*; +pub use tokio_graceful::{Shutdown, ShutdownGuard, WeakShutdownGuard}; -use crate::{state::Extendable, Layer, Service}; +use crate::{ + service::{Layer, Service}, + state::Extendable, +}; pub struct ShutdownGuardAdder { inner: S, - guard: ShutdownGuard, + guard: WeakShutdownGuard, } impl ShutdownGuardAdder { - fn new(inner: S, guard: ShutdownGuard) -> Self { + fn new(inner: S, guard: WeakShutdownGuard) -> Self { Self { inner, guard } } } @@ -22,7 +25,7 @@ where type Error = S::Error; async fn call(&mut self, mut request: Request) -> Result { - let guard = self.guard.clone(); + let guard = self.guard.clone().upgrade(); request.extensions_mut().insert(guard); self.inner.call(request).await } @@ -42,6 +45,6 @@ impl Layer for ShutdownGuardAdderLayer { type Service = ShutdownGuardAdder; fn layer(&self, inner: S) -> Self::Service { - ShutdownGuardAdder::new(inner, self.guard.clone().upgrade()) + ShutdownGuardAdder::new(inner, self.guard.clone()) } } diff --git a/rama/src/lib.rs b/rama/src/lib.rs index dc46bbcc..9848af83 100644 --- a/rama/src/lib.rs +++ b/rama/src/lib.rs @@ -1,9 +1,9 @@ #![feature(async_fn_in_trait)] +#![feature(return_type_notation)] +#![allow(incomplete_features)] pub mod graceful; pub mod server; +pub mod service; pub mod state; pub mod stream; - -pub use tower_async_layer::Layer; -pub use tower_async_service::Service; diff --git a/rama/src/server/tcp/listener.rs b/rama/src/server/tcp/listener.rs index 78e5ef07..5352dbdc 100644 --- a/rama/src/server/tcp/listener.rs +++ b/rama/src/server/tcp/listener.rs @@ -2,15 +2,19 @@ use std::{io, net::SocketAddr}; use tokio::net::{TcpListener as TokioTcpListener, ToSocketAddrs}; -use crate::Service; +use crate::service::{ + util::{Identity, Stack}, + Layer, Service, ServiceBuilder, +}; use super::TcpStream; -pub struct TcpListener { +pub struct TcpListener { inner: TokioTcpListener, + builder: ServiceBuilder, } -impl TcpListener { +impl TcpListener { /// Creates a new TcpListener, which will be bound to the specified address. /// /// The returned listener is ready for accepting connections. @@ -18,11 +22,14 @@ impl TcpListener { /// Binding with a port number of 0 will request that the OS assigns a port /// to this listener. The port allocated can be queried via the `local_addr` /// method. - pub async fn bind(addr: A) -> io::Result { + pub async fn bind(addr: A) -> io::Result { let inner = TokioTcpListener::bind(addr).await?; - Ok(TcpListener { inner }) + let builder = ServiceBuilder::new(); + Ok(TcpListener { inner, builder }) } +} +impl TcpListener { /// Returns the local address that this listener is bound to. /// /// This can be useful, for example, when binding to port 0 to figure out @@ -48,14 +55,32 @@ impl TcpListener { self.inner.set_ttl(ttl) } + /// Adds a layer to the service. + /// + /// This method can be used to add a middleware to the service. + pub fn layer(self, layer: M) -> TcpListener> + where + M: tower_async::layer::Layer, + { + TcpListener { + inner: self.inner, + builder: self.builder.layer(layer), + } + } +} + +impl TcpListener { /// Serve connections from this listener with the given service. /// /// This method will block the current listener for each incoming connection, /// the underlying service can choose to spawn a task to handle the accepted stream. - pub async fn serve(self, mut service: S) -> TcpServeResult + pub async fn serve(self, service: S) -> TcpServeResult where - S: Service, + L: Layer, + L::Service: Service, { + let mut service = self.builder.service(service); + loop { let (stream, _) = self.inner.accept().await?; let stream = TcpStream::new(stream); diff --git a/rama/src/service/mod.rs b/rama/src/service/mod.rs new file mode 100644 index 00000000..a6908383 --- /dev/null +++ b/rama/src/service/mod.rs @@ -0,0 +1,7 @@ +pub use tower_async::{service_fn, Layer, Service, ServiceBuilder}; + +pub mod util { + pub use tower_async::layer::util::{Identity, Stack}; +} + +pub mod spawn; diff --git a/rama/src/service/spawn.rs b/rama/src/service/spawn.rs new file mode 100644 index 00000000..a1a1b25f --- /dev/null +++ b/rama/src/service/spawn.rs @@ -0,0 +1,62 @@ +use crate::{ + graceful::ShutdownGuard, + service::{Layer, Service}, + state::Extendable, +}; + +pub struct SpawnService { + inner: S, +} + +impl SpawnService { + pub fn new(inner: S) -> Self { + Self { inner } + } +} + +impl Service for SpawnService +where + S: Service + Clone + Send + 'static, + S::Error: std::error::Error, + Request: Extendable + Send + 'static, +{ + type Response = (); + type Error = std::convert::Infallible; + + async fn call(&mut self, request: Request) -> Result { + let mut service = self.inner.clone(); + if let Some(guard) = request.extensions().get::() { + guard.clone().spawn_task(async move { + if let Err(err) = service.call(request).await { + tracing::error!( + error = &err as &dyn std::error::Error, + "graceful service error" + ); + } + }); + } else { + tokio::spawn(async move { + if let Err(err) = service.call(request).await { + tracing::error!(error = &err as &dyn std::error::Error, "service error"); + } + }); + } + Ok(()) + } +} + +pub struct SpawnLayer(()); + +impl SpawnLayer { + pub fn new() -> Self { + Self(()) + } +} + +impl Layer for SpawnLayer { + type Service = SpawnService; + + fn layer(&self, inner: S) -> Self::Service { + SpawnService::new(inner) + } +} diff --git a/rama/src/stream/service/echo.rs b/rama/src/stream/service/echo.rs index f6d2d306..d6d5b173 100644 --- a/rama/src/stream/service/echo.rs +++ b/rama/src/stream/service/echo.rs @@ -1,13 +1,13 @@ use std::io::Error; -use crate::{stream::Stream, Service}; +use crate::{service::Service, stream::Stream}; /// An async service which echoes the incoming bytes back on the same stream. /// /// # Example /// /// ```rust -/// use rama::{stream::service::EchoService, Service}; +/// use rama::{service::Service, stream::service::EchoService}; /// /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { diff --git a/rama/src/stream/service/forward.rs b/rama/src/stream/service/forward.rs index 058e65f8..f6fe7210 100644 --- a/rama/src/stream/service/forward.rs +++ b/rama/src/stream/service/forward.rs @@ -1,6 +1,6 @@ use std::{io::Error, pin::Pin}; -use crate::{stream::Stream, Service}; +use crate::{service::Service, stream::Stream}; /// Async service which forwards the incoming connection bytes to the given destination, /// and forwards the response back from the destination to the incoming connection. @@ -8,7 +8,7 @@ use crate::{stream::Stream, Service}; /// # Example /// /// ```rust -/// use rama::{stream::service::ForwardService, Service}; +/// use rama::{service::Service, stream::service::ForwardService}; /// /// # #[tokio::main] /// # async fn main() -> Result<(), Box> {