diff --git a/rama/Cargo.toml b/rama/Cargo.toml index 537b06d6..9660a758 100644 --- a/rama/Cargo.toml +++ b/rama/Cargo.toml @@ -15,6 +15,8 @@ http = "0.2.9" matchit = "0.7.3" pin-project-lite = "0.2.13" tokio = { version = "1.33.0", features = ["net", "io-util"] } +tokio-graceful = "0.1.5" +tower-async-layer = "0.1.1" tower-async-service = "0.1.1" [dev-dependencies] diff --git a/rama/src/graceful.rs b/rama/src/graceful.rs new file mode 100644 index 00000000..ef216a40 --- /dev/null +++ b/rama/src/graceful.rs @@ -0,0 +1,50 @@ +use tower_async_layer::Layer; +use tower_async_service::Service; + +pub use tokio_graceful::*; + +use crate::state::Extendable; + +pub struct ShutdownGuardAdder { + inner: S, + guard: ShutdownGuard, +} + +impl ShutdownGuardAdder { + fn new(inner: S, guard: ShutdownGuard) -> Self { + Self { inner, guard } + } +} + +impl Service for ShutdownGuardAdder +where + S: Service, + Request: Extendable, +{ + type Response = S::Response; + type Error = S::Error; + + async fn call(&mut self, mut request: Request) -> Result { + let guard = self.guard.clone(); + request.extensions_mut().insert(guard); + self.inner.call(request).await + } +} + +pub struct ShutdownGuardAdderLayer { + guard: ShutdownGuard, +} + +impl ShutdownGuardAdderLayer { + pub fn new(guard: ShutdownGuard) -> Self { + Self { guard } + } +} + +impl Layer for ShutdownGuardAdderLayer { + type Service = ShutdownGuardAdder; + + fn layer(&self, inner: S) -> Self::Service { + ShutdownGuardAdder::new(inner, self.guard.clone()) + } +} diff --git a/rama/src/lib.rs b/rama/src/lib.rs index d113bab8..4683fca0 100644 --- a/rama/src/lib.rs +++ b/rama/src/lib.rs @@ -1,5 +1,6 @@ #![feature(async_fn_in_trait)] +pub mod graceful; pub mod server; pub mod state; pub mod stream;