diff --git a/engineioxide/Cargo.toml b/engineioxide/Cargo.toml index d2257d6e..a886fbc9 100644 --- a/engineioxide/Cargo.toml +++ b/engineioxide/Cargo.toml @@ -23,13 +23,7 @@ bytes = "1.4.0" futures = "0.3.27" http = "0.2.9" http-body = "0.4.5" -hyper = { version = "0.14.25", features = [ - "http1", - "http2", - "server", - "stream", - "runtime", -] } +hyper = { version = "1.0.0-rc.4", features = ["http1", "client", "server"] } pin-project = "1.0.12" serde = { version = "1.0.155", features = ["derive"] } serde_json = "1.0.94" @@ -46,14 +40,10 @@ unicode-segmentation = { version = "1.10.1", optional = true } criterion = { version = "0.5.1", features = ["html_reports", "async_tokio"] } tokio = { version = "1.26.0", features = ["macros", "parking_lot"] } tracing-subscriber = { version = "0.3.16", features = ["env-filter"] } -hyper = { version = "0.14.25", features = [ - "http1", - "http2", - "server", - "stream", - "runtime", - "client", +hyper-util = { git = "/~https://github.com/hyperium/hyper-util", features = [ + "full", ] } +http-body-util = "0.1.0-rc.3" [features] default = ["v4"] diff --git a/engineioxide/src/service.rs b/engineioxide/src/service.rs index 6c70987d..9ffce4d9 100644 --- a/engineioxide/src/service.rs +++ b/engineioxide/src/service.rs @@ -12,7 +12,7 @@ use bytes::Bytes; use futures::future::{ready, Ready}; use http::{Method, Request}; use http_body::{Body, Empty}; -use hyper::{service::Service, Response}; +use hyper::{body::Incoming, Response}; use std::{ convert::Infallible, fmt::Debug, @@ -20,6 +20,7 @@ use std::{ sync::Arc, task::{Context, Poll}, }; +use tower::Service; /// A [`Service`] that handles engine.io requests as a middleware. /// If the request is not an engine.io request, it forwards it to the inner service. @@ -74,13 +75,10 @@ impl Clone for EngineIoService { } /// The service implementation for [`EngineIoService`]. -impl Service> for EngineIoService +impl Service> for EngineIoService where ResBody: Body + Send + 'static, - ReqBody: Body + Send + Unpin + 'static + Debug, - ::Error: Debug, - ::Data: Send, - S: Service, Response = Response>, + S: Service, Response = Response>, H: EngineIoHandler, { type Response = Response>; @@ -95,7 +93,7 @@ where /// Each request is parsed to a [`RequestInfo`] /// If the request is an `EngineIo` request, it is handled by the corresponding [`transport`](crate::transport). /// Otherwise, it is forwarded to the inner service. - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { if req.uri().path().starts_with(&self.engine.config.req_path) { let engine = self.engine.clone(); match RequestInfo::parse(&req, &self.engine.config) { diff --git a/engineioxide/src/transport/polling/mod.rs b/engineioxide/src/transport/polling/mod.rs index 4b960954..3449e97a 100644 --- a/engineioxide/src/transport/polling/mod.rs +++ b/engineioxide/src/transport/polling/mod.rs @@ -122,7 +122,6 @@ pub async fn post_req( where H: EngineIoHandler, R: Body + Send + Unpin + 'static, - ::Error: std::fmt::Debug, ::Data: Send, B: Send + 'static, { diff --git a/engineioxide/src/transport/polling/payload/decoder.rs b/engineioxide/src/transport/polling/payload/decoder.rs index 1cc24dd4..426896d0 100644 --- a/engineioxide/src/transport/polling/payload/decoder.rs +++ b/engineioxide/src/transport/polling/payload/decoder.rs @@ -40,10 +40,7 @@ impl Payload { /// Polls the body stream for data and adds it to the chunk list in the state /// Returns an error if the packet length exceeds the maximum allowed payload size -async fn poll_body( - state: &mut Payload + Unpin>, - max_payload: u64, -) -> Result<(), Error> { +async fn poll_body(state: &mut Payload, max_payload: u64) -> Result<(), Error> { match state.body.data().await.transpose() { Ok(Some(data)) if state.current_payload_size + (data.remaining() as u64) <= max_payload => { state.current_payload_size += data.remaining() as u64; @@ -56,7 +53,7 @@ async fn poll_body( Ok(()) } Err(e) => { - debug!("error reading body stream: {:?}", e); + // debug!("error reading body stream: {:?}", e); Err(Error::HttpErrorResponse(StatusCode::BAD_REQUEST)) } } @@ -64,7 +61,7 @@ async fn poll_body( #[cfg(feature = "v4")] pub fn v4_decoder( - body: impl http_body::Body + Unpin, + body: impl http_body::Body + Unpin, max_payload: u64, ) -> impl Stream> { use super::PACKET_SEPARATOR_V4; @@ -113,7 +110,7 @@ pub fn v4_decoder( #[cfg(feature = "v3")] pub fn v3_binary_decoder( - body: impl http_body::Body + Unpin, + body: impl http_body::Body + Unpin, max_payload: u64, ) -> impl Stream> { use std::io::Read; @@ -205,7 +202,7 @@ pub fn v3_binary_decoder( #[cfg(feature = "v3")] pub fn v3_string_decoder( - body: impl http_body::Body + Unpin, + body: impl http_body::Body + Unpin, max_payload: u64, ) -> impl Stream> { use std::io::ErrorKind; @@ -372,7 +369,7 @@ mod tests { const DATA: &[u8] = "4foo\x1e4€f\x1e4fo".as_bytes(); for i in 1..DATA.len() { println!("payload stream v4 chunk size: {i}"); - let stream = hyper::Body::wrap_stream(futures::stream::iter( + let stream = http_body_util::StreamBody::new(futures::stream::iter( DATA.chunks(i).map(Ok::<_, std::convert::Infallible>), )); let payload = v4_decoder(stream, MAX_PAYLOAD); @@ -400,7 +397,7 @@ mod tests { const DATA: &[u8] = "4foo\x1e4€f\x1e4fo".as_bytes(); const MAX_PAYLOAD: u64 = 3; for i in 1..DATA.len() { - let stream = hyper::Body::wrap_stream(futures::stream::iter( + let stream = http_body_util::StreamBody::new(futures::stream::iter( DATA.chunks(i).map(Ok::<_, std::convert::Infallible>), )); let payload = v4_decoder(stream, MAX_PAYLOAD); @@ -463,7 +460,7 @@ mod tests { const DATA: &[u8] = "4:4foo3:4€f11:4baaaaaaaar".as_bytes(); for i in 1..DATA.len() { println!("payload stream v3 chunk size: {i}"); - let stream = hyper::Body::wrap_stream(futures::stream::iter( + let stream = http_body_util::StreamBody::new(futures::stream::iter( DATA.chunks(i).map(Ok::<_, std::convert::Infallible>), )); let payload = v3_string_decoder(stream, MAX_PAYLOAD); @@ -497,7 +494,7 @@ mod tests { for i in 1..PAYLOAD.len() { println!("payload stream v3 chunk size: {i}"); - let stream = hyper::Body::wrap_stream(futures::stream::iter( + let stream = http_body_util::StreamBody::new(futures::stream::iter( PAYLOAD.chunks(i).map(Ok::<_, std::convert::Infallible>), )); let payload = v3_binary_decoder(stream, MAX_PAYLOAD); @@ -521,7 +518,7 @@ mod tests { const DATA: &[u8] = "4:4foo3:4€f11:4baaaaaaaar".as_bytes(); const MAX_PAYLOAD: u64 = 3; for i in 1..DATA.len() { - let stream = hyper::Body::wrap_stream(futures::stream::iter( + let stream = http_body_util::StreamBody::new(futures::stream::iter( DATA.chunks(i).map(Ok::<_, std::convert::Infallible>), )); let payload = v3_binary_decoder(stream, MAX_PAYLOAD); @@ -530,7 +527,7 @@ mod tests { assert!(matches!(packet, Err(Error::PayloadTooLarge))); } for i in 1..DATA.len() { - let stream = hyper::Body::wrap_stream(futures::stream::iter( + let stream = http_body_util::StreamBody::new(futures::stream::iter( DATA.chunks(i).map(Ok::<_, std::convert::Infallible>), )); let payload = v3_string_decoder(stream, MAX_PAYLOAD); diff --git a/engineioxide/src/transport/polling/payload/mod.rs b/engineioxide/src/transport/polling/payload/mod.rs index 064a4c9d..f7b7f6db 100644 --- a/engineioxide/src/transport/polling/payload/mod.rs +++ b/engineioxide/src/transport/polling/payload/mod.rs @@ -21,7 +21,7 @@ const STRING_PACKET_IDENTIFIER_V3: u8 = 0x00; const BINARY_PACKET_IDENTIFIER_V3: u8 = 0x01; pub fn decoder( - body: Request + Unpin>, + body: Request, #[allow(unused_variables)] protocol: ProtocolVersion, max_payload: u64, ) -> impl Stream> { diff --git a/engineioxide/src/transport/ws.rs b/engineioxide/src/transport/ws/mod.rs similarity index 96% rename from engineioxide/src/transport/ws.rs rename to engineioxide/src/transport/ws/mod.rs index 7113cb9f..910d0788 100644 --- a/engineioxide/src/transport/ws.rs +++ b/engineioxide/src/transport/ws/mod.rs @@ -33,6 +33,9 @@ use crate::{ DisconnectReason, Socket, SocketReq, }; +mod tokio_io; +use tokio_io::TokioIo; + /// Upgrade a websocket request to create a websocket connection. /// /// If a sid is provided in the query it means that is is upgraded from an existing HTTP polling request. In this case @@ -77,7 +80,7 @@ async fn on_init( sid: Option, req_data: SocketReq, ) -> Result<(), Error> { - let ws_init = move || WebSocketStream::from_raw_socket(conn, Role::Server, None); + let ws_init = move || WebSocketStream::from_raw_socket(TokioIo::new(conn), Role::Server, None); let (socket, ws) = if let Some(sid) = sid { match engine.get_socket(sid) { None => return Err(Error::UnknownSessionID(sid)), @@ -124,7 +127,7 @@ async fn on_init( /// Forwards all packets received from a websocket to a EngineIo [`Socket`] async fn forward_to_handler( engine: &Arc>, - mut rx: SplitStream>, + mut rx: SplitStream>>, socket: &Arc>, ) -> Result<(), Error> { while let Some(msg) = rx.try_next().await? { @@ -161,7 +164,7 @@ async fn forward_to_handler( /// The websocket stream is flushed only when the internal channel is drained fn forward_to_socket( socket: Arc>, - mut tx: SplitSink, Message>, + mut tx: SplitSink>, Message>, ) -> JoinHandle<()> { // Pipe between websocket and internal socket channel tokio::spawn(async move { @@ -210,7 +213,7 @@ fn forward_to_socket( /// Send a Engine.IO [`OpenPacket`] to initiate a websocket connection async fn init_handshake( sid: Sid, - ws: &mut WebSocketStream, + ws: &mut WebSocketStream>, config: &EngineIoConfig, ) -> Result<(), Error> { let packet = Packet::Open(OpenPacket::new(TransportType::Websocket, sid, config)); @@ -245,7 +248,7 @@ async fn init_handshake( async fn upgrade_handshake( protocol: ProtocolVersion, socket: &Arc>, - ws: &mut WebSocketStream, + ws: &mut WebSocketStream>, ) -> Result<(), Error> { debug!("websocket connection upgrade"); diff --git a/engineioxide/src/transport/ws/tokio_io.rs b/engineioxide/src/transport/ws/tokio_io.rs new file mode 100644 index 00000000..0f7c10b9 --- /dev/null +++ b/engineioxide/src/transport/ws/tokio_io.rs @@ -0,0 +1,161 @@ +#![allow(dead_code)] +//! Tokio IO integration for hyper +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use pin_project::pin_project; +/// A wrapping implementing hyper IO traits for a type that +/// implements Tokio's IO traits. +#[derive(Debug)] +#[pin_project] +pub struct TokioIo { + #[pin] + inner: T, +} + +impl TokioIo { + /// Wrap a type implementing Tokio's IO traits. + pub fn new(inner: T) -> Self { + Self { inner } + } + + /// Borrow the inner type. + pub fn inner(&self) -> &T { + &self.inner + } + + /// Consume this wrapper and get the inner type. + pub fn into_inner(self) -> T { + self.inner + } +} + +impl hyper::rt::Read for TokioIo +where + T: tokio::io::AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: hyper::rt::ReadBufCursor<'_>, + ) -> Poll> { + let n = unsafe { + let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); + match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) { + Poll::Ready(Ok(())) => tbuf.filled().len(), + other => return other, + } + }; + + unsafe { + buf.advance(n); + } + Poll::Ready(Ok(())) + } +} + +impl hyper::rt::Write for TokioIo +where + T: tokio::io::AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) + } + + fn is_write_vectored(&self) -> bool { + tokio::io::AsyncWrite::is_write_vectored(&self.inner) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs) + } +} + +impl tokio::io::AsyncRead for TokioIo +where + T: hyper::rt::Read, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + tbuf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + //let init = tbuf.initialized().len(); + let filled = tbuf.filled().len(); + let sub_filled = unsafe { + let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut()); + + match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) { + Poll::Ready(Ok(())) => buf.filled().len(), + other => return other, + } + }; + + let n_filled = filled + sub_filled; + // At least sub_filled bytes had to have been initialized. + let n_init = sub_filled; + unsafe { + tbuf.assume_init(n_init); + tbuf.set_filled(n_filled); + } + + Poll::Ready(Ok(())) + } +} + +impl tokio::io::AsyncWrite for TokioIo +where + T: hyper::rt::Write, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + hyper::rt::Write::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + hyper::rt::Write::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + hyper::rt::Write::poll_shutdown(self.project().inner, cx) + } + + fn is_write_vectored(&self) -> bool { + hyper::rt::Write::is_write_vectored(&self.inner) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs) + } +} diff --git a/engineioxide/tests/fixture.rs b/engineioxide/tests/fixture.rs index 20d36ecf..e7edd71f 100644 --- a/engineioxide/tests/fixture.rs +++ b/engineioxide/tests/fixture.rs @@ -3,12 +3,16 @@ use std::{ time::Duration, }; -use bytes::Buf; +use bytes::{Buf, Bytes}; use engineioxide::{config::EngineIoConfig, handler::EngineIoHandler, service::EngineIoService}; -use http::Request; -use hyper::Server; +use http_body_util::{BodyExt, Full}; +use hyper_util::{ + client::{connect::HttpConnector, legacy::Client}, + rt::TokioExecutor, + server::conn::auto, +}; use serde::{Deserialize, Serialize}; -use tokio::net::TcpStream; +use tokio::net::{TcpListener, TcpStream}; use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; /// An OpenPacket is used to initiate a connection @@ -30,9 +34,9 @@ pub async fn send_req( body: Option, ) -> String { let body = body - .map(|b| hyper::Body::from(b)) - .unwrap_or_else(hyper::Body::empty); - let req = Request::builder() + .map(|b| Full::new(Bytes::from(b))) + .unwrap_or(Full::new(Bytes::new())); + let req = hyper::Request::builder() .method(method) .uri(format!( "http://127.0.0.1:{port}/engine.io/?EIO=4&{}", @@ -40,8 +44,10 @@ pub async fn send_req( )) .body(body) .unwrap(); - let mut res = hyper::Client::new().request(req).await.unwrap(); - let body = hyper::body::aggregate(res.body_mut()).await.unwrap(); + + let client = Client::builder(hyper_util::rt::TokioExecutor::new()).build(HttpConnector::new()); + let mut res = client.request(req).await.unwrap(); + let body = res.body_mut().collect().await.unwrap().to_bytes(); String::from_utf8(body.chunk().to_vec()) .unwrap() .chars() @@ -63,7 +69,7 @@ pub async fn create_ws_connection(port: u16) -> WebSocketStream(handler: H, port: u16) { +pub async fn create_server(handler: H, port: u16) { let config = EngineIoConfig::builder() .ping_interval(Duration::from_millis(300)) .ping_timeout(Duration::from_millis(200)) @@ -74,7 +80,20 @@ pub fn create_server(handler: H, port: u16) { let svc = EngineIoService::with_config(handler, config); - let server = Server::bind(addr).serve(svc.into_make_service()); + // Tcp listener on addr + let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); + let listener = TcpListener::bind(addr).await.unwrap(); + + let local_addr = listener.local_addr().unwrap(); - tokio::spawn(server); + tokio::spawn(async move { + loop { + let (stream, _) = listener.accept().await.unwrap(); + tokio::task::spawn(async move { + let _ = auto::Builder::new(TokioExecutor::new()) + .serve_connection(stream, svc) + .await; + }); + } + }); }