diff --git a/Cargo.lock b/Cargo.lock index 960e32cb..dab2effb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -223,6 +223,15 @@ version = "1.0.7" source = "registry+/~https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "form_urlencoded" +version = "1.2.0" +source = "registry+/~https://github.com/rust-lang/crates.io-index" +checksum = "a62bc1cf6f830c2ec14a513a9fb124d0a213a629668a4186f329db21fe045652" +dependencies = [ + "percent-encoding", +] + [[package]] name = "futures" version = "0.3.29" @@ -528,6 +537,12 @@ dependencies = [ "serde", ] +[[package]] +name = "percent-encoding" +version = "2.3.0" +source = "registry+/~https://github.com/rust-lang/crates.io-index" +checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" + [[package]] name = "pin-project-lite" version = "0.2.13" @@ -575,6 +590,8 @@ dependencies = [ "rcgen", "rustls", "rustls-pki-types", + "serde", + "serde_urlencoded", "tokio", "tokio-graceful", "tokio-rustls", @@ -741,18 +758,18 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "serde" -version = "1.0.191" +version = "1.0.192" source = "registry+/~https://github.com/rust-lang/crates.io-index" -checksum = "a834c4821019838224821468552240d4d95d14e751986442c816572d39a080c9" +checksum = "bca2a08484b285dcb282d0f67b26cadc0df8b19f8c12502c13d966bf9482f001" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.191" +version = "1.0.192" source = "registry+/~https://github.com/rust-lang/crates.io-index" -checksum = "46fa52d5646bce91b680189fe5b1c049d2ea38dabb4e2e7c8d00ca12cfbfbcfd" +checksum = "d6c7207fbec9faa48073f3e3074cbe553af6ea512d7c21ba46e434e70ea9fbc1" dependencies = [ "proc-macro2", "quote", @@ -770,6 +787,18 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+/~https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + [[package]] name = "sharded-slab" version = "0.1.7" diff --git a/rama/Cargo.toml b/rama/Cargo.toml index 26b0ebe5..00b89c30 100644 --- a/rama/Cargo.toml +++ b/rama/Cargo.toml @@ -15,6 +15,8 @@ http = "0.2.9" matchit = "0.7.3" pin-project-lite = "0.2.13" rustls = "0.22.0-alpha.3" +serde = "1.0.192" +serde_urlencoded = "0.7.1" tokio = { version = "1.33.0", features = ["net", "io-util"] } tokio-graceful = "0.1.5" tokio-rustls = "0.25.0-alpha.1" diff --git a/rama/examples/tokio_tcp_echo_server.rs b/rama/examples/tokio_tcp_echo_server.rs index e159e0be..8b1f956a 100644 --- a/rama/examples/tokio_tcp_echo_server.rs +++ b/rama/examples/tokio_tcp_echo_server.rs @@ -7,7 +7,8 @@ use rama::{ server::tcp::TcpListener, service::{limit::ConcurrentPolicy, Layer, Service}, state::Extendable, - stream::service::{BytesRWTrackerHandle, EchoService}, + stream::layer::BytesRWTrackerHandle, + stream::service::EchoService, }; use tracing::metadata::LevelFilter; diff --git a/rama/src/net/http/headers.rs b/rama/src/net/http/headers.rs new file mode 100644 index 00000000..00e590c0 --- /dev/null +++ b/rama/src/net/http/headers.rs @@ -0,0 +1,45 @@ +use http::{ + header::{AsHeaderName, GetAll}, + HeaderValue, Request, Response, +}; + +pub trait HeaderValueGetter { + fn header_value(&self, key: K) -> Option<&HeaderValue> + where + K: AsHeaderName; + fn header_values(&self, key: K) -> GetAll<'_, HeaderValue> + where + K: AsHeaderName; +} + +impl HeaderValueGetter for Request { + fn header_value(&self, key: K) -> Option<&HeaderValue> + where + K: AsHeaderName, + { + self.headers().get(key) + } + + fn header_values(&self, key: K) -> GetAll<'_, HeaderValue> + where + K: AsHeaderName, + { + self.headers().get_all(key) + } +} + +impl HeaderValueGetter for Response { + fn header_value(&self, key: K) -> Option<&HeaderValue> + where + K: AsHeaderName, + { + self.headers().get(key) + } + + fn header_values(&self, key: K) -> GetAll<'_, HeaderValue> + where + K: AsHeaderName, + { + self.headers().get_all(key) + } +} diff --git a/rama/src/net/http/mod.rs b/rama/src/net/http/mod.rs new file mode 100644 index 00000000..0a80dbac --- /dev/null +++ b/rama/src/net/http/mod.rs @@ -0,0 +1,2 @@ +mod headers; +pub use headers::HeaderValueGetter; diff --git a/rama/src/net/mod.rs b/rama/src/net/mod.rs new file mode 100644 index 00000000..6c8b1086 --- /dev/null +++ b/rama/src/net/mod.rs @@ -0,0 +1,6 @@ +mod tcp; +pub use tcp::TcpStream; + +pub mod http; + +pub use tokio::net::ToSocketAddrs; diff --git a/rama/src/net/tcp.rs b/rama/src/net/tcp.rs new file mode 100644 index 00000000..bff3f43b --- /dev/null +++ b/rama/src/net/tcp.rs @@ -0,0 +1,114 @@ +use std::{ + io, + net::SocketAddr, + pin::Pin, + task::{Context, Poll}, +}; + +use tokio::net::TcpStream as TokioTcpStream; + +use crate::{ + state::{Extendable, Extensions}, + stream::{AsyncRead, AsyncWrite, ReadBuf}, +}; + +pin_project_lite::pin_project! { + #[derive(Debug)] + pub struct TcpStream { + #[pin] + inner: S, + extensions: Extensions, + } +} + +impl TcpStream { + pub fn new(inner: S) -> Self { + Self { + inner, + extensions: Extensions::new(), + } + } + + pub fn into_parts(self) -> (S, Extensions) { + (self.inner, self.extensions) + } + + pub fn from_parts(inner: S, extensions: Extensions) -> Self { + Self { inner, extensions } + } +} + +impl TcpStream { + pub fn peer_addr(&self) -> io::Result { + self.inner.peer_addr() + } + + pub fn local_addr(&self) -> io::Result { + self.inner.local_addr() + } + + pub fn ttl(&self) -> io::Result { + self.inner.ttl() + } + + pub fn set_ttl(&self, ttl: u32) -> io::Result<()> { + self.inner.set_ttl(ttl) + } +} + +impl Extendable for TcpStream { + fn extensions(&self) -> &Extensions { + &self.extensions + } + + fn extensions_mut(&mut self) -> &mut Extensions { + &mut self.extensions + } +} + +impl AsyncRead for TcpStream +where + S: AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.project().inner.poll_read(cx, buf) + } +} + +impl AsyncWrite for TcpStream +where + S: AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project().inner.poll_write(cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + self.project().inner.poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_shutdown(cx) + } +} diff --git a/rama/src/server/http/header.rs b/rama/src/server/http/header.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/rama/src/server/http/header.rs @@ -0,0 +1 @@ + diff --git a/rama/src/server/http/layer/mod.rs b/rama/src/server/http/layer/mod.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/rama/src/server/http/layer/mod.rs @@ -0,0 +1 @@ + diff --git a/rama/src/server/http/mod.rs b/rama/src/server/http/mod.rs index 89bb5918..1f587155 100644 --- a/rama/src/server/http/mod.rs +++ b/rama/src/server/http/mod.rs @@ -1,3 +1,5 @@ +pub mod header; +pub mod layer; pub mod service; mod conn; diff --git a/rama/src/server/tcp/listener.rs b/rama/src/server/tcp/listener.rs index 71d3329b..3e4d52ca 100644 --- a/rama/src/server/tcp/listener.rs +++ b/rama/src/server/tcp/listener.rs @@ -92,9 +92,9 @@ impl TcpListener { /// This can be used to track the number of bytes read and written, /// by using the [`BytesRWTrackerHandle`] found in the extensions. /// - /// [`BytesRWTrackerHandle`]: crate::stream::service::BytesRWTrackerHandle - pub fn bytes_tracker(self) -> TcpListener> { - self.layer(crate::stream::service::BytesTrackerLayer::new()) + /// [`BytesRWTrackerHandle`]: crate::stream::layer::BytesRWTrackerHandle + pub fn bytes_tracker(self) -> TcpListener> { + self.layer(crate::stream::layer::BytesTrackerLayer::new()) } /// Fail requests that take longer than `timeout`. diff --git a/rama/src/stream/layer/mod.rs b/rama/src/stream/layer/mod.rs new file mode 100644 index 00000000..a83330ee --- /dev/null +++ b/rama/src/stream/layer/mod.rs @@ -0,0 +1,2 @@ +mod tracker; +pub use tracker::{BytesRWTrackerHandle, BytesTrackerLayer, BytesTrackerService}; diff --git a/rama/src/stream/service/tracker/bytes.rs b/rama/src/stream/layer/tracker/bytes.rs similarity index 100% rename from rama/src/stream/service/tracker/bytes.rs rename to rama/src/stream/layer/tracker/bytes.rs diff --git a/rama/src/stream/service/tracker/mod.rs b/rama/src/stream/layer/tracker/mod.rs similarity index 100% rename from rama/src/stream/service/tracker/mod.rs rename to rama/src/stream/layer/tracker/mod.rs diff --git a/rama/src/stream/mod.rs b/rama/src/stream/mod.rs index b8a0e5e2..52157700 100644 --- a/rama/src/stream/mod.rs +++ b/rama/src/stream/mod.rs @@ -1,5 +1,6 @@ pub use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +pub mod layer; pub mod service; pub trait Stream: AsyncRead + AsyncWrite {} diff --git a/rama/src/stream/service/mod.rs b/rama/src/stream/service/mod.rs index 04b6825e..8663ec2b 100644 --- a/rama/src/stream/service/mod.rs +++ b/rama/src/stream/service/mod.rs @@ -3,6 +3,3 @@ pub use echo::EchoService; mod forward; pub use forward::ForwardService; - -mod tracker; -pub use tracker::{BytesRWTrackerHandle, BytesTrackerLayer, BytesTrackerService};