Skip to content

Commit

Permalink
feat(engineio): hyper 1.0.0
Browse files Browse the repository at this point in the history
  • Loading branch information
Totodore committed Oct 22, 2023
1 parent 0f9e0c1 commit 2dd0d4a
Show file tree
Hide file tree
Showing 8 changed files with 221 additions and 54 deletions.
18 changes: 4 additions & 14 deletions engineioxide/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,7 @@ bytes = "1.4.0"
futures = "0.3.27"
http = "0.2.9"
http-body = "0.4.5"
hyper = { version = "0.14.25", features = [
"http1",
"http2",
"server",
"stream",
"runtime",
] }
hyper = { version = "1.0.0-rc.4", features = ["http1", "client", "server"] }
pin-project = "1.0.12"
serde = { version = "1.0.155", features = ["derive"] }
serde_json = "1.0.94"
Expand All @@ -46,14 +40,10 @@ unicode-segmentation = { version = "1.10.1", optional = true }
criterion = { version = "0.5.1", features = ["html_reports", "async_tokio"] }
tokio = { version = "1.26.0", features = ["macros", "parking_lot"] }
tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }
hyper = { version = "0.14.25", features = [
"http1",
"http2",
"server",
"stream",
"runtime",
"client",
hyper-util = { git = "/~https://github.com/hyperium/hyper-util", features = [
"full",
] }
http-body-util = "0.1.0-rc.3"

[features]
default = ["v4"]
Expand Down
12 changes: 5 additions & 7 deletions engineioxide/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@ use bytes::Bytes;
use futures::future::{ready, Ready};
use http::{Method, Request};
use http_body::{Body, Empty};
use hyper::{service::Service, Response};
use hyper::{body::Incoming, Response};
use std::{
convert::Infallible,
fmt::Debug,
str::FromStr,
sync::Arc,
task::{Context, Poll},
};
use tower::Service;

/// A [`Service`] that handles engine.io requests as a middleware.
/// If the request is not an engine.io request, it forwards it to the inner service.
Expand Down Expand Up @@ -74,13 +75,10 @@ impl<S: Clone, H: EngineIoHandler> Clone for EngineIoService<H, S> {
}

/// The service implementation for [`EngineIoService`].
impl<ReqBody, ResBody, S, H> Service<Request<ReqBody>> for EngineIoService<H, S>
impl<ResBody, S, H> Service<Request<Incoming>> for EngineIoService<H, S>
where
ResBody: Body + Send + 'static,
ReqBody: Body + Send + Unpin + 'static + Debug,
<ReqBody as Body>::Error: Debug,
<ReqBody as Body>::Data: Send,
S: Service<Request<ReqBody>, Response = Response<ResBody>>,
S: Service<Request<Incoming>, Response = Response<ResBody>>,
H: EngineIoHandler,
{
type Response = Response<ResponseBody<ResBody>>;
Expand All @@ -95,7 +93,7 @@ where
/// Each request is parsed to a [`RequestInfo`]
/// If the request is an `EngineIo` request, it is handled by the corresponding [`transport`](crate::transport).
/// Otherwise, it is forwarded to the inner service.
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
fn call(&mut self, req: Request<Incoming>) -> Self::Future {
if req.uri().path().starts_with(&self.engine.config.req_path) {
let engine = self.engine.clone();
match RequestInfo::parse(&req, &self.engine.config) {
Expand Down
1 change: 0 additions & 1 deletion engineioxide/src/transport/polling/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ pub async fn post_req<R, B, H>(
where
H: EngineIoHandler,
R: Body + Send + Unpin + 'static,
<R as Body>::Error: std::fmt::Debug,
<R as Body>::Data: Send,
B: Send + 'static,
{
Expand Down
25 changes: 11 additions & 14 deletions engineioxide/src/transport/polling/payload/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,7 @@ impl<B: http_body::Body> Payload<B> {

/// Polls the body stream for data and adds it to the chunk list in the state
/// Returns an error if the packet length exceeds the maximum allowed payload size
async fn poll_body(
state: &mut Payload<impl http_body::Body<Error = impl std::fmt::Debug> + Unpin>,
max_payload: u64,
) -> Result<(), Error> {
async fn poll_body(state: &mut Payload<Incoming>, max_payload: u64) -> Result<(), Error> {
match state.body.data().await.transpose() {
Ok(Some(data)) if state.current_payload_size + (data.remaining() as u64) <= max_payload => {
state.current_payload_size += data.remaining() as u64;
Expand All @@ -56,15 +53,15 @@ async fn poll_body(
Ok(())
}
Err(e) => {
debug!("error reading body stream: {:?}", e);
// debug!("error reading body stream: {:?}", e);
Err(Error::HttpErrorResponse(StatusCode::BAD_REQUEST))
}
}
}

#[cfg(feature = "v4")]
pub fn v4_decoder(
body: impl http_body::Body<Error = impl std::fmt::Debug> + Unpin,
body: impl http_body::Body + Unpin,
max_payload: u64,
) -> impl Stream<Item = Result<Packet, Error>> {
use super::PACKET_SEPARATOR_V4;
Expand Down Expand Up @@ -113,7 +110,7 @@ pub fn v4_decoder(

#[cfg(feature = "v3")]
pub fn v3_binary_decoder(
body: impl http_body::Body<Error = impl std::fmt::Debug> + Unpin,
body: impl http_body::Body + Unpin,
max_payload: u64,
) -> impl Stream<Item = Result<Packet, Error>> {
use std::io::Read;
Expand Down Expand Up @@ -205,7 +202,7 @@ pub fn v3_binary_decoder(

#[cfg(feature = "v3")]
pub fn v3_string_decoder(
body: impl http_body::Body<Error = impl std::fmt::Debug> + Unpin,
body: impl http_body::Body + Unpin,
max_payload: u64,
) -> impl Stream<Item = Result<Packet, Error>> {
use std::io::ErrorKind;
Expand Down Expand Up @@ -372,7 +369,7 @@ mod tests {
const DATA: &[u8] = "4foo\x1e4€f\x1e4fo".as_bytes();
for i in 1..DATA.len() {
println!("payload stream v4 chunk size: {i}");
let stream = hyper::Body::wrap_stream(futures::stream::iter(
let stream = http_body_util::StreamBody::new(futures::stream::iter(
DATA.chunks(i).map(Ok::<_, std::convert::Infallible>),
));
let payload = v4_decoder(stream, MAX_PAYLOAD);
Expand Down Expand Up @@ -400,7 +397,7 @@ mod tests {
const DATA: &[u8] = "4foo\x1e4€f\x1e4fo".as_bytes();
const MAX_PAYLOAD: u64 = 3;
for i in 1..DATA.len() {
let stream = hyper::Body::wrap_stream(futures::stream::iter(
let stream = http_body_util::StreamBody::new(futures::stream::iter(
DATA.chunks(i).map(Ok::<_, std::convert::Infallible>),
));
let payload = v4_decoder(stream, MAX_PAYLOAD);
Expand Down Expand Up @@ -463,7 +460,7 @@ mod tests {
const DATA: &[u8] = "4:4foo3:4€f11:4baaaaaaaar".as_bytes();
for i in 1..DATA.len() {
println!("payload stream v3 chunk size: {i}");
let stream = hyper::Body::wrap_stream(futures::stream::iter(
let stream = http_body_util::StreamBody::new(futures::stream::iter(
DATA.chunks(i).map(Ok::<_, std::convert::Infallible>),
));
let payload = v3_string_decoder(stream, MAX_PAYLOAD);
Expand Down Expand Up @@ -497,7 +494,7 @@ mod tests {

for i in 1..PAYLOAD.len() {
println!("payload stream v3 chunk size: {i}");
let stream = hyper::Body::wrap_stream(futures::stream::iter(
let stream = http_body_util::StreamBody::new(futures::stream::iter(
PAYLOAD.chunks(i).map(Ok::<_, std::convert::Infallible>),
));
let payload = v3_binary_decoder(stream, MAX_PAYLOAD);
Expand All @@ -521,7 +518,7 @@ mod tests {
const DATA: &[u8] = "4:4foo3:4€f11:4baaaaaaaar".as_bytes();
const MAX_PAYLOAD: u64 = 3;
for i in 1..DATA.len() {
let stream = hyper::Body::wrap_stream(futures::stream::iter(
let stream = http_body_util::StreamBody::new(futures::stream::iter(
DATA.chunks(i).map(Ok::<_, std::convert::Infallible>),
));
let payload = v3_binary_decoder(stream, MAX_PAYLOAD);
Expand All @@ -530,7 +527,7 @@ mod tests {
assert!(matches!(packet, Err(Error::PayloadTooLarge)));
}
for i in 1..DATA.len() {
let stream = hyper::Body::wrap_stream(futures::stream::iter(
let stream = http_body_util::StreamBody::new(futures::stream::iter(
DATA.chunks(i).map(Ok::<_, std::convert::Infallible>),
));
let payload = v3_string_decoder(stream, MAX_PAYLOAD);
Expand Down
2 changes: 1 addition & 1 deletion engineioxide/src/transport/polling/payload/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ const STRING_PACKET_IDENTIFIER_V3: u8 = 0x00;
const BINARY_PACKET_IDENTIFIER_V3: u8 = 0x01;

pub fn decoder(
body: Request<impl http_body::Body<Error = impl std::fmt::Debug> + Unpin>,
body: Request<impl http_body::Body + Unpin>,
#[allow(unused_variables)] protocol: ProtocolVersion,
max_payload: u64,
) -> impl Stream<Item = Result<Packet, Error>> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ use crate::{
DisconnectReason, Socket, SocketReq,
};

mod tokio_io;
use tokio_io::TokioIo;

/// Upgrade a websocket request to create a websocket connection.
///
/// If a sid is provided in the query it means that is is upgraded from an existing HTTP polling request. In this case
Expand Down Expand Up @@ -77,7 +80,7 @@ async fn on_init<H: EngineIoHandler>(
sid: Option<Sid>,
req_data: SocketReq,
) -> Result<(), Error> {
let ws_init = move || WebSocketStream::from_raw_socket(conn, Role::Server, None);
let ws_init = move || WebSocketStream::from_raw_socket(TokioIo::new(conn), Role::Server, None);
let (socket, ws) = if let Some(sid) = sid {
match engine.get_socket(sid) {
None => return Err(Error::UnknownSessionID(sid)),
Expand Down Expand Up @@ -124,7 +127,7 @@ async fn on_init<H: EngineIoHandler>(
/// Forwards all packets received from a websocket to a EngineIo [`Socket`]
async fn forward_to_handler<H: EngineIoHandler>(
engine: &Arc<EngineIo<H>>,
mut rx: SplitStream<WebSocketStream<Upgraded>>,
mut rx: SplitStream<WebSocketStream<TokioIo<Upgraded>>>,
socket: &Arc<Socket<H::Data>>,
) -> Result<(), Error> {
while let Some(msg) = rx.try_next().await? {
Expand Down Expand Up @@ -161,7 +164,7 @@ async fn forward_to_handler<H: EngineIoHandler>(
/// The websocket stream is flushed only when the internal channel is drained
fn forward_to_socket<H: EngineIoHandler>(
socket: Arc<Socket<H::Data>>,
mut tx: SplitSink<WebSocketStream<Upgraded>, Message>,
mut tx: SplitSink<WebSocketStream<TokioIo<Upgraded>>, Message>,
) -> JoinHandle<()> {
// Pipe between websocket and internal socket channel
tokio::spawn(async move {
Expand Down Expand Up @@ -210,7 +213,7 @@ fn forward_to_socket<H: EngineIoHandler>(
/// Send a Engine.IO [`OpenPacket`] to initiate a websocket connection
async fn init_handshake(
sid: Sid,
ws: &mut WebSocketStream<Upgraded>,
ws: &mut WebSocketStream<TokioIo<Upgraded>>,
config: &EngineIoConfig,
) -> Result<(), Error> {
let packet = Packet::Open(OpenPacket::new(TransportType::Websocket, sid, config));
Expand Down Expand Up @@ -245,7 +248,7 @@ async fn init_handshake(
async fn upgrade_handshake<H: EngineIoHandler>(
protocol: ProtocolVersion,
socket: &Arc<Socket<H::Data>>,
ws: &mut WebSocketStream<Upgraded>,
ws: &mut WebSocketStream<TokioIo<Upgraded>>,
) -> Result<(), Error> {
debug!("websocket connection upgrade");

Expand Down
Loading

0 comments on commit 2dd0d4a

Please sign in to comment.