diff --git a/benches/end_to_end.rs b/benches/end_to_end.rs index 89c3caf4e2..3558e5c611 100644 --- a/benches/end_to_end.rs +++ b/benches/end_to_end.rs @@ -4,8 +4,7 @@ extern crate test; mod support; -// TODO: Reimplement Opts::bench using hyper::server::conn and hyper::client::conn -// (instead of Server and HttpClient). +// TODO: Reimplement parallel for HTTP/1 use std::convert::Infallible; use std::net::SocketAddr; @@ -315,7 +314,8 @@ impl Opts { let mut client = rt.block_on(async { if self.http2 { - let io = tokio::net::TcpStream::connect(&addr).await.unwrap(); + let tcp = tokio::net::TcpStream::connect(&addr).await.unwrap(); + let io = support::TokioIo::new(tcp); let (tx, conn) = hyper::client::conn::http2::Builder::new(support::TokioExecutor) .initial_stream_window_size(self.http2_stream_window) .initial_connection_window_size(self.http2_conn_window) @@ -328,7 +328,8 @@ impl Opts { } else if self.parallel_cnt > 1 { todo!("http/1 parallel >1"); } else { - let io = tokio::net::TcpStream::connect(&addr).await.unwrap(); + let tcp = tokio::net::TcpStream::connect(&addr).await.unwrap(); + let io = support::TokioIo::new(tcp); let (tx, conn) = hyper::client::conn::http1::Builder::new() .handshake(io) .await @@ -414,6 +415,7 @@ fn spawn_server(rt: &tokio::runtime::Runtime, opts: &Opts) -> SocketAddr { let opts = opts.clone(); rt.spawn(async move { while let Ok((sock, _)) = listener.accept().await { + let io = support::TokioIo::new(sock); if opts.http2 { tokio::spawn( hyper::server::conn::http2::Builder::new(support::TokioExecutor) @@ -421,7 +423,7 @@ fn spawn_server(rt: &tokio::runtime::Runtime, opts: &Opts) -> SocketAddr { .initial_connection_window_size(opts.http2_conn_window) .adaptive_window(opts.http2_adaptive_window) .serve_connection( - sock, + io, service_fn(move |req: Request| async move { let mut req_body = req.into_body(); while let Some(_chunk) = req_body.frame().await {} @@ -433,7 +435,7 @@ fn spawn_server(rt: &tokio::runtime::Runtime, opts: &Opts) -> SocketAddr { ); } else { tokio::spawn(hyper::server::conn::http1::Builder::new().serve_connection( - sock, + io, service_fn(move |req: Request| async move { let mut req_body = req.into_body(); while let Some(_chunk) = req_body.frame().await {} diff --git a/benches/pipeline.rs b/benches/pipeline.rs index a60100fa51..b79232de9b 100644 --- a/benches/pipeline.rs +++ b/benches/pipeline.rs @@ -3,6 +3,8 @@ extern crate test; +mod support; + use std::convert::Infallible; use std::io::{Read, Write}; use std::net::{SocketAddr, TcpStream}; @@ -40,11 +42,12 @@ fn hello_world_16(b: &mut test::Bencher) { rt.spawn(async move { loop { let (stream, _addr) = listener.accept().await.expect("accept"); + let io = support::TokioIo::new(stream); http1::Builder::new() .pipeline_flush(true) .serve_connection( - stream, + io, service_fn(|_| async { Ok::<_, Infallible>(Response::new(Full::new(Bytes::from( "Hello, World!", diff --git a/benches/server.rs b/benches/server.rs index 17eefa0694..c5424105a8 100644 --- a/benches/server.rs +++ b/benches/server.rs @@ -3,6 +3,8 @@ extern crate test; +mod support; + use std::io::{Read, Write}; use std::net::{SocketAddr, TcpListener, TcpStream}; use std::sync::mpsc; @@ -38,10 +40,11 @@ macro_rules! bench_server { rt.spawn(async move { loop { let (stream, _) = listener.accept().await.expect("accept"); + let io = support::TokioIo::new(stream); http1::Builder::new() .serve_connection( - stream, + io, service_fn(|_| async { Ok::<_, hyper::Error>( Response::builder() diff --git a/benches/support/mod.rs b/benches/support/mod.rs index 48e8048e8b..85cb67fd33 100644 --- a/benches/support/mod.rs +++ b/benches/support/mod.rs @@ -1,2 +1,2 @@ mod tokiort; -pub use tokiort::{TokioExecutor, TokioTimer}; +pub use tokiort::{TokioExecutor, TokioIo, TokioTimer}; diff --git a/benches/support/tokiort.rs b/benches/support/tokiort.rs index 4708bd67a1..9a16e0ebad 100644 --- a/benches/support/tokiort.rs +++ b/benches/support/tokiort.rs @@ -88,3 +88,149 @@ impl TokioSleep { self.project().inner.as_mut().reset(deadline.into()); } } + +pin_project! { + #[derive(Debug)] + pub struct TokioIo { + #[pin] + inner: T, + } +} + +impl TokioIo { + pub fn new(inner: T) -> Self { + Self { inner } + } + + pub fn inner(self) -> T { + self.inner + } +} + +impl hyper::rt::Read for TokioIo +where + T: tokio::io::AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: hyper::rt::ReadBufCursor<'_>, + ) -> Poll> { + let n = unsafe { + let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); + match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) { + Poll::Ready(Ok(())) => tbuf.filled().len(), + other => return other, + } + }; + + unsafe { + buf.advance(n); + } + Poll::Ready(Ok(())) + } +} + +impl hyper::rt::Write for TokioIo +where + T: tokio::io::AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) + } + + fn is_write_vectored(&self) -> bool { + tokio::io::AsyncWrite::is_write_vectored(&self.inner) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs) + } +} + +impl tokio::io::AsyncRead for TokioIo +where + T: hyper::rt::Read, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + tbuf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + //let init = tbuf.initialized().len(); + let filled = tbuf.filled().len(); + let sub_filled = unsafe { + let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut()); + + match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) { + Poll::Ready(Ok(())) => buf.filled().len(), + other => return other, + } + }; + + let n_filled = filled + sub_filled; + // At least sub_filled bytes had to have been initialized. + let n_init = sub_filled; + unsafe { + tbuf.assume_init(n_init); + tbuf.set_filled(n_filled); + } + + Poll::Ready(Ok(())) + } +} + +impl tokio::io::AsyncWrite for TokioIo +where + T: hyper::rt::Write, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + hyper::rt::Write::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + hyper::rt::Write::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + hyper::rt::Write::poll_shutdown(self.project().inner, cx) + } + + fn is_write_vectored(&self) -> bool { + hyper::rt::Write::is_write_vectored(&self.inner) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs) + } +} diff --git a/examples/client.rs b/examples/client.rs index 23f63a7143..1301f278b4 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -7,6 +7,10 @@ use http_body_util::{BodyExt, Empty}; use hyper::Request; use tokio::net::TcpStream; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + // A simple type alias so as to DRY. type Result = std::result::Result>; @@ -39,8 +43,9 @@ async fn fetch_url(url: hyper::Uri) -> Result<()> { let port = url.port_u16().unwrap_or(80); let addr = format!("{}:{}", host, port); let stream = TcpStream::connect(addr).await?; + let io = TokioIo::new(stream); - let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await?; + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; tokio::task::spawn(async move { if let Err(err) = conn.await { println!("Connection failed: {:?}", err); diff --git a/examples/client_json.rs b/examples/client_json.rs index 9d9c25a7af..9449ec8b37 100644 --- a/examples/client_json.rs +++ b/examples/client_json.rs @@ -7,6 +7,10 @@ use hyper::{body::Buf, Request}; use serde::Deserialize; use tokio::net::TcpStream; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + // A simple type alias so as to DRY. type Result = std::result::Result>; @@ -29,8 +33,9 @@ async fn fetch_json(url: hyper::Uri) -> Result> { let addr = format!("{}:{}", host, port); let stream = TcpStream::connect(addr).await?; + let io = TokioIo::new(stream); - let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await?; + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; tokio::task::spawn(async move { if let Err(err) = conn.await { println!("Connection failed: {:?}", err); diff --git a/examples/echo.rs b/examples/echo.rs index ff5808da47..52278a5232 100644 --- a/examples/echo.rs +++ b/examples/echo.rs @@ -10,6 +10,10 @@ use hyper::service::service_fn; use hyper::{body::Body, Method, Request, Response, StatusCode}; use tokio::net::TcpListener; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + /// This is our service handler. It receives a Request, routes on its /// path, and returns a Future of a Response. async fn echo( @@ -92,10 +96,11 @@ async fn main() -> Result<(), Box> { println!("Listening on http://{}", addr); loop { let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); tokio::task::spawn(async move { if let Err(err) = http1::Builder::new() - .serve_connection(stream, service_fn(echo)) + .serve_connection(io, service_fn(echo)) .await { println!("Error serving connection: {:?}", err); diff --git a/examples/gateway.rs b/examples/gateway.rs index f77a916233..b18d7033c5 100644 --- a/examples/gateway.rs +++ b/examples/gateway.rs @@ -4,6 +4,10 @@ use hyper::{server::conn::http1, service::service_fn}; use std::net::SocketAddr; use tokio::net::{TcpListener, TcpStream}; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + #[tokio::main(flavor="current_thread")] async fn main() -> Result<(), Box> { pretty_env_logger::init(); @@ -20,6 +24,7 @@ async fn main() -> Result<(), Box> { loop { let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); // This is the `Service` that will handle the connection. // `service_fn` is a helper to convert a function that @@ -42,9 +47,9 @@ async fn main() -> Result<(), Box> { async move { let client_stream = TcpStream::connect(addr).await.unwrap(); + let io = TokioIo::new(client_stream); - let (mut sender, conn) = - hyper::client::conn::http1::handshake(client_stream).await?; + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; tokio::task::spawn(async move { if let Err(err) = conn.await { println!("Connection failed: {:?}", err); @@ -56,10 +61,7 @@ async fn main() -> Result<(), Box> { }); tokio::task::spawn(async move { - if let Err(err) = http1::Builder::new() - .serve_connection(stream, service) - .await - { + if let Err(err) = http1::Builder::new().serve_connection(io, service).await { println!("Failed to serve the connection: {:?}", err); } }); diff --git a/examples/hello.rs b/examples/hello.rs index e0530bc0e0..d06b1df5a2 100644 --- a/examples/hello.rs +++ b/examples/hello.rs @@ -10,6 +10,10 @@ use hyper::service::service_fn; use hyper::{Request, Response}; use tokio::net::TcpListener; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + // An async function that consumes a request, does nothing with it and returns a // response. async fn hello(_: Request) -> Result>, Infallible> { @@ -35,7 +39,10 @@ pub async fn main() -> Result<(), Box> { // has work to do. In this case, a connection arrives on the port we are listening on and // the task is woken up, at which point the task is then put back on a thread, and is // driven forward by the runtime, eventually yielding a TCP stream. - let (stream, _) = listener.accept().await?; + let (tcp, _) = listener.accept().await?; + // Use an adapter to access something implementing `tokio::io` traits as if they implement + // `hyper::rt` IO traits. + let io = TokioIo::new(tcp); // Spin up a new task in Tokio so we can continue to listen for new TCP connection on the // current task without waiting for the processing of the HTTP1 connection we just received @@ -44,7 +51,7 @@ pub async fn main() -> Result<(), Box> { // Handle the connection from the client using HTTP1 and pass any // HTTP requests received on that connection to the `hello` function if let Err(err) = http1::Builder::new() - .serve_connection(stream, service_fn(hello)) + .serve_connection(io, service_fn(hello)) .await { println!("Error serving connection: {:?}", err); diff --git a/examples/http_proxy.rs b/examples/http_proxy.rs index 0b4a6818b8..c36cc23778 100644 --- a/examples/http_proxy.rs +++ b/examples/http_proxy.rs @@ -12,6 +12,10 @@ use hyper::{Method, Request, Response}; use tokio::net::{TcpListener, TcpStream}; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + // To try this example: // 1. cargo run --example http_proxy // 2. config http_proxy in command line @@ -28,12 +32,13 @@ async fn main() -> Result<(), Box> { loop { let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); tokio::task::spawn(async move { if let Err(err) = http1::Builder::new() .preserve_header_case(true) .title_case_headers(true) - .serve_connection(stream, service_fn(proxy)) + .serve_connection(io, service_fn(proxy)) .with_upgrades() .await { @@ -88,11 +93,12 @@ async fn proxy( let addr = format!("{}:{}", host, port); let stream = TcpStream::connect(addr).await.unwrap(); + let io = TokioIo::new(stream); let (mut sender, conn) = Builder::new() .preserve_header_case(true) .title_case_headers(true) - .handshake(stream) + .handshake(io) .await?; tokio::task::spawn(async move { if let Err(err) = conn.await { @@ -123,9 +129,10 @@ fn full>(chunk: T) -> BoxBody { // Create a TCP connection to host:port, build a tunnel between the connection and // the upgraded connection -async fn tunnel(mut upgraded: Upgraded, addr: String) -> std::io::Result<()> { +async fn tunnel(upgraded: Upgraded, addr: String) -> std::io::Result<()> { // Connect to remote server let mut server = TcpStream::connect(addr).await?; + let mut upgraded = TokioIo::new(upgraded); // Proxying data let (from_client, from_server) = diff --git a/examples/multi_server.rs b/examples/multi_server.rs index 0cb9a79efb..0a29848abc 100644 --- a/examples/multi_server.rs +++ b/examples/multi_server.rs @@ -11,6 +11,10 @@ use hyper::service::service_fn; use hyper::{Request, Response}; use tokio::net::TcpListener; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + static INDEX1: &[u8] = b"The 1st service!"; static INDEX2: &[u8] = b"The 2nd service!"; @@ -33,10 +37,11 @@ async fn main() -> Result<(), Box> { let listener = TcpListener::bind(addr1).await.unwrap(); loop { let (stream, _) = listener.accept().await.unwrap(); + let io = TokioIo::new(stream); tokio::task::spawn(async move { if let Err(err) = http1::Builder::new() - .serve_connection(stream, service_fn(index1)) + .serve_connection(io, service_fn(index1)) .await { println!("Error serving connection: {:?}", err); @@ -49,10 +54,11 @@ async fn main() -> Result<(), Box> { let listener = TcpListener::bind(addr2).await.unwrap(); loop { let (stream, _) = listener.accept().await.unwrap(); + let io = TokioIo::new(stream); tokio::task::spawn(async move { if let Err(err) = http1::Builder::new() - .serve_connection(stream, service_fn(index2)) + .serve_connection(io, service_fn(index2)) .await { println!("Error serving connection: {:?}", err); diff --git a/examples/params.rs b/examples/params.rs index 56333b731b..d946e2ee01 100644 --- a/examples/params.rs +++ b/examples/params.rs @@ -13,6 +13,10 @@ use std::convert::Infallible; use std::net::SocketAddr; use url::form_urlencoded; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + static INDEX: &[u8] = b"
Name:
Number:
"; static MISSING: &[u8] = b"Missing field"; static NOTNUMERIC: &[u8] = b"Number field is not numeric"; @@ -124,10 +128,11 @@ async fn main() -> Result<(), Box> { println!("Listening on http://{}", addr); loop { let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); tokio::task::spawn(async move { if let Err(err) = http1::Builder::new() - .serve_connection(stream, service_fn(param_example)) + .serve_connection(io, service_fn(param_example)) .await { println!("Error serving connection: {:?}", err); diff --git a/examples/send_file.rs b/examples/send_file.rs index a30d72c2e3..ed821ceb41 100644 --- a/examples/send_file.rs +++ b/examples/send_file.rs @@ -10,6 +10,10 @@ use http_body_util::Full; use hyper::service::service_fn; use hyper::{Method, Request, Response, Result, StatusCode}; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + static INDEX: &str = "examples/send_file_index.html"; static NOTFOUND: &[u8] = b"Not Found"; @@ -24,10 +28,11 @@ async fn main() -> std::result::Result<(), Box> { loop { let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); tokio::task::spawn(async move { if let Err(err) = http1::Builder::new() - .serve_connection(stream, service_fn(response_examples)) + .serve_connection(io, service_fn(response_examples)) .await { println!("Failed to serve connection: {:?}", err); diff --git a/examples/service_struct_impl.rs b/examples/service_struct_impl.rs index 1e22b8033f..bb908977f4 100644 --- a/examples/service_struct_impl.rs +++ b/examples/service_struct_impl.rs @@ -10,6 +10,10 @@ use std::net::SocketAddr; use std::pin::Pin; use std::sync::Mutex; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + type Counter = i32; #[tokio::main(flavor="current_thread")] @@ -21,11 +25,12 @@ async fn main() -> Result<(), Box> { loop { let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); tokio::task::spawn(async move { if let Err(err) = http1::Builder::new() .serve_connection( - stream, + io, Svc { counter: Mutex::new(81818), }, diff --git a/examples/single_threaded.rs b/examples/single_threaded.rs index de6256239c..83b64ebf21 100644 --- a/examples/single_threaded.rs +++ b/examples/single_threaded.rs @@ -5,7 +5,7 @@ use hyper::server::conn::http2; use std::cell::Cell; use std::net::SocketAddr; use std::rc::Rc; -use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::io::{self, AsyncWriteExt}; use tokio::net::TcpListener; use hyper::body::{Body as HttpBody, Bytes, Frame}; @@ -18,6 +18,10 @@ use std::task::{Context, Poll}; use std::thread; use tokio::net::TcpStream; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + struct Body { // Our Body type is !Send and !Sync: _marker: PhantomData<*const ()>, @@ -98,6 +102,7 @@ async fn server() -> Result<(), Box> { loop { let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); // For each connection, clone the counter to use in our service... let cnt = counter.clone(); @@ -111,7 +116,7 @@ async fn server() -> Result<(), Box> { tokio::task::spawn_local(async move { if let Err(err) = http2::Builder::new(LocalExec) - .serve_connection(stream, service) + .serve_connection(io, service) .await { let mut stdout = io::stdout(); @@ -127,11 +132,11 @@ async fn server() -> Result<(), Box> { struct IOTypeNotSend { _marker: PhantomData<*const ()>, - stream: TcpStream, + stream: TokioIo, } impl IOTypeNotSend { - fn new(stream: TcpStream) -> Self { + fn new(stream: TokioIo) -> Self { Self { _marker: PhantomData, stream, @@ -139,7 +144,7 @@ impl IOTypeNotSend { } } -impl AsyncWrite for IOTypeNotSend { +impl hyper::rt::Write for IOTypeNotSend { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -163,11 +168,11 @@ impl AsyncWrite for IOTypeNotSend { } } -impl AsyncRead for IOTypeNotSend { +impl hyper::rt::Read for IOTypeNotSend { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, + buf: hyper::rt::ReadBufCursor<'_>, ) -> Poll> { Pin::new(&mut self.stream).poll_read(cx, buf) } @@ -179,7 +184,7 @@ async fn client(url: hyper::Uri) -> Result<(), Box> { let addr = format!("{}:{}", host, port); let stream = TcpStream::connect(addr).await?; - let stream = IOTypeNotSend::new(stream); + let stream = IOTypeNotSend::new(TokioIo::new(stream)); let (mut sender, conn) = hyper::client::conn::http2::handshake(LocalExec, stream).await?; diff --git a/examples/state.rs b/examples/state.rs index 952ed04e4b..a3b2f4a3c5 100644 --- a/examples/state.rs +++ b/examples/state.rs @@ -12,6 +12,10 @@ use hyper::{server::conn::http1, service::service_fn}; use hyper::{Error, Response}; use tokio::net::TcpListener; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + #[tokio::main(flavor="current_thread")] async fn main() -> Result<(), Box> { pretty_env_logger::init(); @@ -26,6 +30,7 @@ async fn main() -> Result<(), Box> { println!("Listening on http://{}", addr); loop { let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); // Each connection could send multiple requests, so // the `Service` needs a clone to handle later requests. @@ -46,10 +51,7 @@ async fn main() -> Result<(), Box> { } }); - if let Err(err) = http1::Builder::new() - .serve_connection(stream, service) - .await - { + if let Err(err) = http1::Builder::new().serve_connection(io, service).await { println!("Error serving connection: {:?}", err); } } diff --git a/examples/upgrades.rs b/examples/upgrades.rs index aeae13c210..c66ad3affa 100644 --- a/examples/upgrades.rs +++ b/examples/upgrades.rs @@ -16,11 +16,16 @@ use hyper::service::service_fn; use hyper::upgrade::Upgraded; use hyper::{Request, Response, StatusCode}; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + // A simple type alias so as to DRY. type Result = std::result::Result>; /// Handle server-side I/O after HTTP upgraded. -async fn server_upgraded_io(mut upgraded: Upgraded) -> Result<()> { +async fn server_upgraded_io(upgraded: Upgraded) -> Result<()> { + let mut upgraded = TokioIo::new(upgraded); // we have an upgraded connection that we can read and // write on directly. // @@ -75,7 +80,8 @@ async fn server_upgrade(mut req: Request) -> Result Result<()> { +async fn client_upgraded_io(upgraded: Upgraded) -> Result<()> { + let mut upgraded = TokioIo::new(upgraded); // We've gotten an upgraded connection that we can read // and write directly on. Let's start out 'foobar' protocol. upgraded.write_all(b"foo=bar").await?; @@ -97,7 +103,8 @@ async fn client_upgrade_request(addr: SocketAddr) -> Result<()> { .unwrap(); let stream = TcpStream::connect(addr).await?; - let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await?; + let io = TokioIo::new(stream); + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; tokio::task::spawn(async move { if let Err(err) = conn.await { @@ -146,10 +153,11 @@ async fn main() { tokio::select! { res = listener.accept() => { let (stream, _) = res.expect("Failed to accept"); + let io = TokioIo::new(stream); let mut rx = rx.clone(); tokio::task::spawn(async move { - let conn = http1::Builder::new().serve_connection(stream, service_fn(server_upgrade)); + let conn = http1::Builder::new().serve_connection(io, service_fn(server_upgrade)); // Don't forget to enable upgrades on the connection. let mut conn = conn.with_upgrades(); diff --git a/examples/web_api.rs b/examples/web_api.rs index e9fd0c5196..8fe16fa585 100644 --- a/examples/web_api.rs +++ b/examples/web_api.rs @@ -9,6 +9,10 @@ use hyper::service::service_fn; use hyper::{body::Incoming as IncomingBody, header, Method, Request, Response, StatusCode}; use tokio::net::{TcpListener, TcpStream}; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + type GenericError = Box; type Result = std::result::Result; type BoxBody = http_body_util::combinators::BoxBody; @@ -30,8 +34,9 @@ async fn client_request_response() -> Result> { let host = req.uri().host().expect("uri has no host"); let port = req.uri().port_u16().expect("uri has no port"); let stream = TcpStream::connect(format!("{}:{}", host, port)).await?; + let io = TokioIo::new(stream); - let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await?; + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; tokio::task::spawn(async move { if let Err(err) = conn.await { @@ -109,14 +114,12 @@ async fn main() -> Result<()> { println!("Listening on http://{}", addr); loop { let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); tokio::task::spawn(async move { let service = service_fn(move |req| response_examples(req)); - if let Err(err) = http1::Builder::new() - .serve_connection(stream, service) - .await - { + if let Err(err) = http1::Builder::new().serve_connection(io, service).await { println!("Failed to serve connection: {:?}", err); } }); diff --git a/src/client/conn/http1.rs b/src/client/conn/http1.rs index cecae92212..a66b42b5b5 100644 --- a/src/client/conn/http1.rs +++ b/src/client/conn/http1.rs @@ -3,10 +3,10 @@ use std::error::Error as StdError; use std::fmt; +use crate::rt::{Read, Write}; use bytes::Bytes; use http::{Request, Response}; use httparse::ParserConfig; -use tokio::io::{AsyncRead, AsyncWrite}; use super::super::dispatch; use crate::body::{Body, Incoming as IncomingBody}; @@ -49,7 +49,7 @@ pub struct Parts { #[must_use = "futures do nothing unless polled"] pub struct Connection where - T: AsyncRead + AsyncWrite + Send + 'static, + T: Read + Write + Send + 'static, B: Body + 'static, { inner: Option>, @@ -57,7 +57,7 @@ where impl Connection where - T: AsyncRead + AsyncWrite + Send + Unpin + 'static, + T: Read + Write + Send + Unpin + 'static, B: Body + 'static, B::Error: Into>, { @@ -114,7 +114,7 @@ pub struct Builder { /// See [`client::conn`](crate::client::conn) for more. pub async fn handshake(io: T) -> crate::Result<(SendRequest, Connection)> where - T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + T: Read + Write + Unpin + Send + 'static, B: Body + 'static, B::Data: Send, B::Error: Into>, @@ -238,7 +238,7 @@ impl fmt::Debug for SendRequest { impl fmt::Debug for Connection where - T: AsyncRead + AsyncWrite + fmt::Debug + Send + 'static, + T: Read + Write + fmt::Debug + Send + 'static, B: Body + 'static, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -248,7 +248,7 @@ where impl Future for Connection where - T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + T: Read + Write + Unpin + Send + 'static, B: Body + Send + 'static, B::Data: Send, B::Error: Into>, @@ -470,7 +470,7 @@ impl Builder { io: T, ) -> impl Future, Connection)>> where - T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + T: Read + Write + Unpin + Send + 'static, B: Body + 'static, B::Data: Send, B::Error: Into>, diff --git a/src/client/conn/http2.rs b/src/client/conn/http2.rs index 16c7af0a3c..c6106d15df 100644 --- a/src/client/conn/http2.rs +++ b/src/client/conn/http2.rs @@ -6,8 +6,8 @@ use std::marker::PhantomData; use std::sync::Arc; use std::time::Duration; +use crate::rt::{Read, Write}; use http::{Request, Response}; -use tokio::io::{AsyncRead, AsyncWrite}; use super::super::dispatch; use crate::body::{Body, Incoming as IncomingBody}; @@ -37,7 +37,7 @@ impl Clone for SendRequest { #[must_use = "futures do nothing unless polled"] pub struct Connection where - T: AsyncRead + AsyncWrite + 'static + Unpin, + T: Read + Write + 'static + Unpin, B: Body + 'static, E: ExecutorClient + Unpin, B::Error: Into>, @@ -64,7 +64,7 @@ pub async fn handshake( io: T, ) -> crate::Result<(SendRequest, Connection)> where - T: AsyncRead + AsyncWrite + Unpin + 'static, + T: Read + Write + Unpin + 'static, B: Body + 'static, B::Data: Send, B::Error: Into>, @@ -193,7 +193,7 @@ impl fmt::Debug for SendRequest { impl Connection where - T: AsyncRead + AsyncWrite + Unpin + 'static, + T: Read + Write + Unpin + 'static, B: Body + Unpin + Send + 'static, B::Data: Send, B::Error: Into>, @@ -215,7 +215,7 @@ where impl fmt::Debug for Connection where - T: AsyncRead + AsyncWrite + fmt::Debug + 'static + Unpin, + T: Read + Write + fmt::Debug + 'static + Unpin, B: Body + 'static, E: ExecutorClient + Unpin, B::Error: Into>, @@ -227,7 +227,7 @@ where impl Future for Connection where - T: AsyncRead + AsyncWrite + Unpin + 'static, + T: Read + Write + Unpin + 'static, B: Body + 'static + Unpin, B::Data: Send, E: Unpin, @@ -398,7 +398,7 @@ where io: T, ) -> impl Future, Connection)>> where - T: AsyncRead + AsyncWrite + Unpin + 'static, + T: Read + Write + Unpin + 'static, B: Body + 'static, B::Data: Send, B::Error: Into>, diff --git a/src/client/conn/mod.rs b/src/client/conn/mod.rs index a70d86e5d3..eda436a8b8 100644 --- a/src/client/conn/mod.rs +++ b/src/client/conn/mod.rs @@ -9,7 +9,9 @@ //! higher-level [Client](super) API. //! //! ## Example -//! A simple example that uses the `SendRequest` struct to talk HTTP over a Tokio TCP stream +//! +//! A simple example that uses the `SendRequest` struct to talk HTTP over some TCP stream. +//! //! ```no_run //! # #[cfg(all(feature = "client", feature = "http1"))] //! # mod rt { @@ -17,38 +19,38 @@ //! use http::{Request, StatusCode}; //! use http_body_util::Empty; //! use hyper::client::conn; -//! use tokio::net::TcpStream; -//! -//! #[tokio::main] -//! async fn main() -> Result<(), Box> { -//! let target_stream = TcpStream::connect("example.com:80").await?; -//! -//! let (mut request_sender, connection) = conn::http1::handshake(target_stream).await?; -//! -//! // spawn a task to poll the connection and drive the HTTP state -//! tokio::spawn(async move { -//! if let Err(e) = connection.await { -//! eprintln!("Error in connection: {}", e); -//! } -//! }); -//! -//! let request = Request::builder() -//! // We need to manually add the host header because SendRequest does not -//! .header("Host", "example.com") -//! .method("GET") -//! .body(Empty::::new())?; -//! let response = request_sender.send_request(request).await?; -//! assert!(response.status() == StatusCode::OK); -//! -//! let request = Request::builder() -//! .header("Host", "example.com") -//! .method("GET") -//! .body(Empty::::new())?; -//! let response = request_sender.send_request(request).await?; -//! assert!(response.status() == StatusCode::OK); -//! Ok(()) -//! } -//! +//! # use hyper::rt::{Read, Write}; +//! # async fn run(tcp: I) -> Result<(), Box> +//! # where +//! # I: Read + Write + Unpin + Send + 'static, +//! # { +//! let (mut request_sender, connection) = conn::http1::handshake(tcp).await?; +//! +//! // spawn a task to poll the connection and drive the HTTP state +//! tokio::spawn(async move { +//! if let Err(e) = connection.await { +//! eprintln!("Error in connection: {}", e); +//! } +//! }); +//! +//! let request = Request::builder() +//! // We need to manually add the host header because SendRequest does not +//! .header("Host", "example.com") +//! .method("GET") +//! .body(Empty::::new())?; +//! +//! let response = request_sender.send_request(request).await?; +//! assert!(response.status() == StatusCode::OK); +//! +//! let request = Request::builder() +//! .header("Host", "example.com") +//! .method("GET") +//! .body(Empty::::new())?; +//! +//! let response = request_sender.send_request(request).await?; +//! assert!(response.status() == StatusCode::OK); +//! # Ok(()) +//! # } //! # } //! ``` diff --git a/src/common/io/compat.rs b/src/common/io/compat.rs new file mode 100644 index 0000000000..3320e4ff44 --- /dev/null +++ b/src/common/io/compat.rs @@ -0,0 +1,150 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// This adapts from `hyper` IO traits to the ones in Tokio. +/// +/// This is currently used by `h2`, and by hyper internal unit tests. +#[derive(Debug)] +pub(crate) struct Compat(pub(crate) T); + +pub(crate) fn compat(io: T) -> Compat { + Compat(io) +} + +impl Compat { + fn p(self: Pin<&mut Self>) -> Pin<&mut T> { + // SAFETY: The simplest of projections. This is just + // a wrapper, we don't do anything that would undo the projection. + unsafe { self.map_unchecked_mut(|me| &mut me.0) } + } +} + +impl tokio::io::AsyncRead for Compat +where + T: crate::rt::Read, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + tbuf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + let init = tbuf.initialized().len(); + let filled = tbuf.filled().len(); + let (new_init, new_filled) = unsafe { + let mut buf = crate::rt::ReadBuf::uninit(tbuf.inner_mut()); + buf.set_init(init); + buf.set_filled(filled); + + match crate::rt::Read::poll_read(self.p(), cx, buf.unfilled()) { + Poll::Ready(Ok(())) => (buf.init_len(), buf.len()), + other => return other, + } + }; + + let n_init = new_init - init; + unsafe { + tbuf.assume_init(n_init); + tbuf.set_filled(new_filled); + } + + Poll::Ready(Ok(())) + } +} + +impl tokio::io::AsyncWrite for Compat +where + T: crate::rt::Write, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + crate::rt::Write::poll_write(self.p(), cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + crate::rt::Write::poll_flush(self.p(), cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + crate::rt::Write::poll_shutdown(self.p(), cx) + } + + fn is_write_vectored(&self) -> bool { + crate::rt::Write::is_write_vectored(&self.0) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + crate::rt::Write::poll_write_vectored(self.p(), cx, bufs) + } +} + +#[cfg(test)] +impl crate::rt::Read for Compat +where + T: tokio::io::AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: crate::rt::ReadBufCursor<'_>, + ) -> Poll> { + let n = unsafe { + let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); + match tokio::io::AsyncRead::poll_read(self.p(), cx, &mut tbuf) { + Poll::Ready(Ok(())) => tbuf.filled().len(), + other => return other, + } + }; + + unsafe { + buf.advance(n); + } + Poll::Ready(Ok(())) + } +} + +#[cfg(test)] +impl crate::rt::Write for Compat +where + T: tokio::io::AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write(self.p(), cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + tokio::io::AsyncWrite::poll_flush(self.p(), cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + tokio::io::AsyncWrite::poll_shutdown(self.p(), cx) + } + + fn is_write_vectored(&self) -> bool { + tokio::io::AsyncWrite::is_write_vectored(&self.0) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write_vectored(self.p(), cx, bufs) + } +} diff --git a/src/common/io/mod.rs b/src/common/io/mod.rs index 2e6d506153..6ad07bb771 100644 --- a/src/common/io/mod.rs +++ b/src/common/io/mod.rs @@ -1,3 +1,7 @@ +#[cfg(any(feature = "http2", test))] +mod compat; mod rewind; +#[cfg(any(feature = "http2", test))] +pub(crate) use self::compat::{compat, Compat}; pub(crate) use self::rewind::Rewind; diff --git a/src/common/io/rewind.rs b/src/common/io/rewind.rs index 5642d897d1..f6b6bab3c7 100644 --- a/src/common/io/rewind.rs +++ b/src/common/io/rewind.rs @@ -2,9 +2,9 @@ use std::marker::Unpin; use std::{cmp, io}; use bytes::{Buf, Bytes}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use crate::common::{task, Pin, Poll}; +use crate::rt::{Read, ReadBufCursor, Write}; /// Combine a buffer with an IO, rewinding reads to use the buffer. #[derive(Debug)] @@ -44,14 +44,14 @@ impl Rewind { // } } -impl AsyncRead for Rewind +impl Read for Rewind where - T: AsyncRead + Unpin, + T: Read + Unpin, { fn poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, - buf: &mut ReadBuf<'_>, + mut buf: ReadBufCursor<'_>, ) -> Poll> { if let Some(mut prefix) = self.pre.take() { // If there are no remaining bytes, let the bytes get dropped. @@ -72,9 +72,9 @@ where } } -impl AsyncWrite for Rewind +impl Write for Rewind where - T: AsyncWrite + Unpin, + T: Write + Unpin, { fn poll_write( mut self: Pin<&mut Self>, @@ -109,6 +109,7 @@ where mod tests { // FIXME: re-implement tests with `async/await`, this import should // trigger a warning to remind us + use super::super::compat; use super::Rewind; use bytes::Bytes; use tokio::io::AsyncReadExt; @@ -120,14 +121,14 @@ mod tests { let mock = tokio_test::io::Builder::new().read(&underlying).build(); - let mut stream = Rewind::new(mock); + let mut stream = compat(Rewind::new(compat(mock))); // Read off some bytes, ensure we filled o1 let mut buf = [0; 2]; stream.read_exact(&mut buf).await.expect("read1"); // Rewind the stream so that it is as if we never read in the first place. - stream.rewind(Bytes::copy_from_slice(&buf[..])); + stream.0.rewind(Bytes::copy_from_slice(&buf[..])); let mut buf = [0; 5]; stream.read_exact(&mut buf).await.expect("read1"); @@ -143,13 +144,13 @@ mod tests { let mock = tokio_test::io::Builder::new().read(&underlying).build(); - let mut stream = Rewind::new(mock); + let mut stream = compat(Rewind::new(compat(mock))); let mut buf = [0; 5]; stream.read_exact(&mut buf).await.expect("read1"); // Rewind the stream so that it is as if we never read in the first place. - stream.rewind(Bytes::copy_from_slice(&buf[..])); + stream.0.rewind(Bytes::copy_from_slice(&buf[..])); let mut buf = [0; 5]; stream.read_exact(&mut buf).await.expect("read1"); diff --git a/src/ffi/io.rs b/src/ffi/io.rs index bff666dbcf..1d198820a6 100644 --- a/src/ffi/io.rs +++ b/src/ffi/io.rs @@ -2,8 +2,8 @@ use std::ffi::c_void; use std::pin::Pin; use std::task::{Context, Poll}; +use crate::rt::{Read, Write}; use libc::size_t; -use tokio::io::{AsyncRead, AsyncWrite}; use super::task::hyper_context; @@ -120,13 +120,13 @@ extern "C" fn write_noop( 0 } -impl AsyncRead for hyper_io { +impl Read for hyper_io { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, + mut buf: crate::rt::ReadBufCursor<'_>, ) -> Poll> { - let buf_ptr = unsafe { buf.unfilled_mut() }.as_mut_ptr() as *mut u8; + let buf_ptr = unsafe { buf.as_mut() }.as_mut_ptr() as *mut u8; let buf_len = buf.remaining(); match (self.read)(self.userdata, hyper_context::wrap(cx), buf_ptr, buf_len) { @@ -138,15 +138,14 @@ impl AsyncRead for hyper_io { ok => { // We have to trust that the user's read callback actually // filled in that many bytes... :( - unsafe { buf.assume_init(ok) }; - buf.advance(ok); + unsafe { buf.advance(ok) }; Poll::Ready(Ok(())) } } } } -impl AsyncWrite for hyper_io { +impl Write for hyper_io { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index 563c2662ce..e0d65bd2d4 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -4,11 +4,11 @@ use std::marker::PhantomData; #[cfg(feature = "server")] use std::time::Duration; +use crate::rt::{Read, Write}; use bytes::{Buf, Bytes}; use http::header::{HeaderValue, CONNECTION}; use http::{HeaderMap, Method, Version}; use httparse::ParserConfig; -use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{debug, error, trace}; use super::io::Buffered; @@ -25,7 +25,7 @@ use crate::rt::Sleep; const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; /// This handles a connection, which will have been established over an -/// `AsyncRead + AsyncWrite` (like a socket), and will likely include multiple +/// `Read + Write` (like a socket), and will likely include multiple /// `Transaction`s over HTTP. /// /// The connection will determine when a message begins and ends as well as @@ -39,7 +39,7 @@ pub(crate) struct Conn { impl Conn where - I: AsyncRead + AsyncWrite + Unpin, + I: Read + Write + Unpin, B: Buf, T: Http1Transaction, { @@ -1044,12 +1044,13 @@ mod tests { #[bench] fn bench_read_head_short(b: &mut ::test::Bencher) { use super::*; + use crate::common::io::Compat; let s = b"GET / HTTP/1.1\r\nHost: localhost:8080\r\n\r\n"; let len = s.len(); b.bytes = len as u64; // an empty IO, we'll be skipping and using the read buffer anyways - let io = tokio_test::io::Builder::new().build(); + let io = Compat(tokio_test::io::Builder::new().build()); let mut conn = Conn::<_, bytes::Bytes, crate::proto::h1::ServerTransaction>::new(io); *conn.io.read_buf_mut() = ::bytes::BytesMut::from(&s[..]); conn.state.cached_headers = Some(HeaderMap::with_capacity(2)); diff --git a/src/proto/h1/decode.rs b/src/proto/h1/decode.rs index 4077b22062..47d9bbd081 100644 --- a/src/proto/h1/decode.rs +++ b/src/proto/h1/decode.rs @@ -428,9 +428,9 @@ impl StdError for IncompleteBody {} #[cfg(test)] mod tests { use super::*; + use crate::rt::{Read, ReadBuf}; use std::pin::Pin; use std::time::Duration; - use tokio::io::{AsyncRead, ReadBuf}; impl<'a> MemRead for &'a [u8] { fn read_mem(&mut self, _: &mut task::Context<'_>, len: usize) -> Poll> { @@ -446,11 +446,11 @@ mod tests { } } - impl<'a> MemRead for &'a mut (dyn AsyncRead + Unpin) { + impl<'a> MemRead for &'a mut (dyn Read + Unpin) { fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll> { let mut v = vec![0; len]; let mut buf = ReadBuf::new(&mut v); - ready!(Pin::new(self).poll_read(cx, &mut buf)?); + ready!(Pin::new(self).poll_read(cx, buf.unfilled())?); Poll::Ready(Ok(Bytes::copy_from_slice(&buf.filled()))) } } @@ -629,7 +629,7 @@ mod tests { async fn read_async(mut decoder: Decoder, content: &[u8], block_at: usize) -> String { let mut outs = Vec::new(); - let mut ins = if block_at == 0 { + let mut ins = crate::common::io::compat(if block_at == 0 { tokio_test::io::Builder::new() .wait(Duration::from_millis(10)) .read(content) @@ -640,9 +640,9 @@ mod tests { .wait(Duration::from_millis(10)) .read(&content[block_at..]) .build() - }; + }); - let mut ins = &mut ins as &mut (dyn AsyncRead + Unpin); + let mut ins = &mut ins as &mut (dyn Read + Unpin); loop { let buf = decoder diff --git a/src/proto/h1/dispatch.rs b/src/proto/h1/dispatch.rs index 32ef001f11..eea31a1105 100644 --- a/src/proto/h1/dispatch.rs +++ b/src/proto/h1/dispatch.rs @@ -1,8 +1,8 @@ use std::error::Error as StdError; +use crate::rt::{Read, Write}; use bytes::{Buf, Bytes}; use http::Request; -use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{debug, trace}; use super::{Http1Transaction, Wants}; @@ -64,7 +64,7 @@ where RecvItem = MessageHead, > + Unpin, D::PollError: Into>, - I: AsyncRead + AsyncWrite + Unpin, + I: Read + Write + Unpin, T: Http1Transaction + Unpin, Bs: Body + 'static, Bs::Error: Into>, @@ -97,7 +97,7 @@ where } /// Run this dispatcher until HTTP says this connection is done, - /// but don't call `AsyncWrite::shutdown` on the underlying IO. + /// but don't call `Write::shutdown` on the underlying IO. /// /// This is useful for old-style HTTP upgrades, but ignores /// newer-style upgrade API. @@ -426,7 +426,7 @@ where RecvItem = MessageHead, > + Unpin, D::PollError: Into>, - I: AsyncRead + AsyncWrite + Unpin, + I: Read + Write + Unpin, T: Http1Transaction + Unpin, Bs: Body + 'static, Bs::Error: Into>, @@ -664,6 +664,7 @@ cfg_client! { #[cfg(test)] mod tests { use super::*; + use crate::common::io::compat; use crate::proto::h1::ClientTransaction; use std::time::Duration; @@ -677,7 +678,7 @@ mod tests { // Block at 0 for now, but we will release this response before // the request is ready to write later... let (mut tx, rx) = crate::client::dispatch::channel(); - let conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(io); + let conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(compat(io)); let mut dispatcher = Dispatcher::new(Client::new(rx), conn); // First poll is needed to allow tx to send... @@ -714,7 +715,7 @@ mod tests { .build_with_handle(); let (mut tx, rx) = crate::client::dispatch::channel(); - let mut conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(io); + let mut conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(compat(io)); conn.set_write_strategy_queue(); let dispatcher = Dispatcher::new(Client::new(rx), conn); @@ -745,7 +746,7 @@ mod tests { .build(); let (mut tx, rx) = crate::client::dispatch::channel(); - let conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(io); + let conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(compat(io)); let mut dispatcher = tokio_test::task::spawn(Dispatcher::new(Client::new(rx), conn)); // First poll is needed to allow tx to send... diff --git a/src/proto/h1/io.rs b/src/proto/h1/io.rs index da4101b6fb..b49cda3dd3 100644 --- a/src/proto/h1/io.rs +++ b/src/proto/h1/io.rs @@ -6,8 +6,8 @@ use std::io::{self, IoSlice}; use std::marker::Unpin; use std::mem::MaybeUninit; +use crate::rt::{Read, ReadBuf, Write}; use bytes::{Buf, BufMut, Bytes, BytesMut}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tracing::{debug, trace}; use super::{Http1Transaction, ParseContext, ParsedMessage}; @@ -55,7 +55,7 @@ where impl Buffered where - T: AsyncRead + AsyncWrite + Unpin, + T: Read + Write + Unpin, B: Buf, { pub(crate) fn new(io: T) -> Buffered { @@ -251,7 +251,7 @@ where let dst = self.read_buf.chunk_mut(); let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit]) }; let mut buf = ReadBuf::uninit(dst); - match Pin::new(&mut self.io).poll_read(cx, &mut buf) { + match Pin::new(&mut self.io).poll_read(cx, buf.unfilled()) { Poll::Ready(Ok(_)) => { let n = buf.filled().len(); trace!("received {} bytes", n); @@ -359,7 +359,7 @@ pub(crate) trait MemRead { impl MemRead for Buffered where - T: AsyncRead + AsyncWrite + Unpin, + T: Read + Write + Unpin, B: Buf, { fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll> { @@ -662,6 +662,7 @@ enum WriteStrategy { #[cfg(test)] mod tests { + use crate::common::io::compat; use crate::common::time::Time; use super::*; @@ -717,7 +718,7 @@ mod tests { .wait(Duration::from_secs(1)) .build(); - let mut buffered = Buffered::<_, Cursor>>::new(mock); + let mut buffered = Buffered::<_, Cursor>>::new(compat(mock)); // We expect a `parse` to be not ready, and so can't await it directly. // Rather, this `poll_fn` will wrap the `Poll` result. @@ -862,7 +863,7 @@ mod tests { #[cfg(debug_assertions)] // needs to trigger a debug_assert fn write_buf_requires_non_empty_bufs() { let mock = Mock::new().build(); - let mut buffered = Buffered::<_, Cursor>>::new(mock); + let mut buffered = Buffered::<_, Cursor>>::new(compat(mock)); buffered.buffer(Cursor::new(Vec::new())); } @@ -897,7 +898,7 @@ mod tests { let mock = Mock::new().write(b"hello world, it's hyper!").build(); - let mut buffered = Buffered::<_, Cursor>>::new(mock); + let mut buffered = Buffered::<_, Cursor>>::new(compat(mock)); buffered.write_buf.set_strategy(WriteStrategy::Flatten); buffered.headers_buf().extend(b"hello "); @@ -956,7 +957,7 @@ mod tests { .write(b"hyper!") .build(); - let mut buffered = Buffered::<_, Cursor>>::new(mock); + let mut buffered = Buffered::<_, Cursor>>::new(compat(mock)); buffered.write_buf.set_strategy(WriteStrategy::Queue); // we have 4 buffers, and vec IO disabled, but explicitly said diff --git a/src/proto/h2/client.rs b/src/proto/h2/client.rs index 56aff85a9f..b8d9951928 100644 --- a/src/proto/h2/client.rs +++ b/src/proto/h2/client.rs @@ -2,6 +2,7 @@ use std::marker::PhantomData; use std::time::Duration; +use crate::rt::{Read, Write}; use bytes::Bytes; use futures_channel::mpsc::{Receiver, Sender}; use futures_channel::{mpsc, oneshot}; @@ -11,13 +12,13 @@ use h2::client::{Builder, Connection, SendRequest}; use h2::SendStream; use http::{Method, StatusCode}; use pin_project_lite::pin_project; -use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{debug, trace, warn}; use super::ping::{Ponger, Recorder}; use super::{ping, H2Upgraded, PipeToSendStream, SendBuf}; use crate::body::{Body, Incoming as IncomingBody}; use crate::client::dispatch::{Callback, SendWhen}; +use crate::common::io::Compat; use crate::common::time::Time; use crate::common::{task, Future, Never, Pin, Poll}; use crate::ext::Protocol; @@ -111,14 +112,14 @@ pub(crate) async fn handshake( timer: Time, ) -> crate::Result> where - T: AsyncRead + AsyncWrite + Unpin + 'static, + T: Read + Write + Unpin + 'static, B: Body + 'static, B::Data: Send + 'static, E: ExecutorClient + Unpin, B::Error: Into>, { let (h2_tx, mut conn) = new_builder(config) - .handshake::<_, SendBuf>(io) + .handshake::<_, SendBuf>(crate::common::io::compat(io)) .await .map_err(crate::Error::new_h2)?; @@ -168,16 +169,16 @@ pin_project! { #[pin] ponger: Ponger, #[pin] - conn: Connection::Data>>, + conn: Connection, SendBuf<::Data>>, } } impl Conn where B: Body, - T: AsyncRead + AsyncWrite + Unpin, + T: Read + Write + Unpin, { - fn new(ponger: Ponger, conn: Connection::Data>>) -> Self { + fn new(ponger: Ponger, conn: Connection, SendBuf<::Data>>) -> Self { Conn { ponger, conn } } } @@ -185,7 +186,7 @@ where impl Future for Conn where B: Body, - T: AsyncRead + AsyncWrite + Unpin, + T: Read + Write + Unpin, { type Output = Result<(), h2::Error>; @@ -211,19 +212,19 @@ pin_project! { struct ConnMapErr where B: Body, - T: AsyncRead, - T: AsyncWrite, + T: Read, + T: Write, T: Unpin, { #[pin] - conn: Either, Connection::Data>>>, + conn: Either, Connection, SendBuf<::Data>>>, } } impl Future for ConnMapErr where B: Body, - T: AsyncRead + AsyncWrite + Unpin, + T: Read + Write + Unpin, { type Output = Result<(), ()>; @@ -239,8 +240,8 @@ pin_project! { pub struct ConnTask where B: Body, - T: AsyncRead, - T: AsyncWrite, + T: Read, + T: Write, T: Unpin, { #[pin] @@ -254,7 +255,7 @@ pin_project! { impl ConnTask where B: Body, - T: AsyncRead + AsyncWrite + Unpin, + T: Read + Write + Unpin, { fn new( conn: ConnMapErr, @@ -272,7 +273,7 @@ where impl Future for ConnTask where B: Body, - T: AsyncRead + AsyncWrite + Unpin, + T: Read + Write + Unpin, { type Output = (); @@ -308,8 +309,8 @@ pin_project! { B: http_body::Body, B: 'static, B::Error: Into>, - T: AsyncRead, - T: AsyncWrite, + T: Read, + T: Write, T: Unpin, { Pipe { @@ -331,7 +332,7 @@ impl Future for H2ClientFuture where B: http_body::Body + 'static, B::Error: Into>, - T: AsyncRead + AsyncWrite + Unpin, + T: Read + Write + Unpin, { type Output = (); @@ -383,7 +384,7 @@ where B: Body + 'static, E: ExecutorClient + Unpin, B::Error: Into>, - T: AsyncRead + AsyncWrite + Unpin, + T: Read + Write + Unpin, { pub(crate) fn is_extended_connect_protocol_enabled(&self) -> bool { self.h2_tx.is_extended_connect_protocol_enabled() @@ -438,7 +439,7 @@ where B::Data: Send, E: ExecutorClient + Unpin, B::Error: Into>, - T: AsyncRead + AsyncWrite + Unpin, + T: Read + Write + Unpin, { fn poll_pipe(&mut self, f: FutCtx, cx: &mut task::Context<'_>) { let ping = self.ping.clone(); @@ -573,7 +574,7 @@ where B::Data: Send, B::Error: Into>, E: ExecutorClient + 'static + Send + Sync + Unpin, - T: AsyncRead + AsyncWrite + Unpin, + T: Read + Write + Unpin, { type Output = crate::Result; diff --git a/src/proto/h2/mod.rs b/src/proto/h2/mod.rs index d0e8c0c323..2002edeb13 100644 --- a/src/proto/h2/mod.rs +++ b/src/proto/h2/mod.rs @@ -1,13 +1,13 @@ +use crate::rt::{Read, ReadBufCursor, Write}; use bytes::{Buf, Bytes}; use h2::{Reason, RecvStream, SendStream}; use http::header::{HeaderName, CONNECTION, TE, TRAILER, TRANSFER_ENCODING, UPGRADE}; use http::HeaderMap; use pin_project_lite::pin_project; use std::error::Error as StdError; -use std::io::{self, Cursor, IoSlice}; +use std::io::{Cursor, IoSlice}; use std::mem; use std::task::Context; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tracing::{debug, trace, warn}; use crate::body::Body; @@ -271,15 +271,15 @@ where buf: Bytes, } -impl AsyncRead for H2Upgraded +impl Read for H2Upgraded where B: Buf, { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - read_buf: &mut ReadBuf<'_>, - ) -> Poll> { + mut read_buf: ReadBufCursor<'_>, + ) -> Poll> { if self.buf.is_empty() { self.buf = loop { match ready!(self.recv_stream.poll_data(cx)) { @@ -295,7 +295,7 @@ where return Poll::Ready(match e.reason() { Some(Reason::NO_ERROR) | Some(Reason::CANCEL) => Ok(()), Some(Reason::STREAM_CLOSED) => { - Err(io::Error::new(io::ErrorKind::BrokenPipe, e)) + Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, e)) } _ => Err(h2_to_io_error(e)), }) @@ -311,7 +311,7 @@ where } } -impl AsyncWrite for H2Upgraded +impl Write for H2Upgraded where B: Buf, { @@ -319,7 +319,7 @@ where mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], - ) -> Poll> { + ) -> Poll> { if buf.is_empty() { return Poll::Ready(Ok(0)); } @@ -344,7 +344,7 @@ where Poll::Ready(Err(h2_to_io_error( match ready!(self.send_stream.poll_reset(cx)) { Ok(Reason::NO_ERROR) | Ok(Reason::CANCEL) | Ok(Reason::STREAM_CLOSED) => { - return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())) + return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())) } Ok(reason) => reason.into(), Err(e) => e, @@ -352,14 +352,14 @@ where ))) } - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn poll_shutdown( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll> { + ) -> Poll> { if self.send_stream.write(&[], true).is_ok() { return Poll::Ready(Ok(())); } @@ -368,7 +368,7 @@ where match ready!(self.send_stream.poll_reset(cx)) { Ok(Reason::NO_ERROR) => return Poll::Ready(Ok(())), Ok(Reason::CANCEL) | Ok(Reason::STREAM_CLOSED) => { - return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())) + return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())) } Ok(reason) => reason.into(), Err(e) => e, @@ -377,11 +377,11 @@ where } } -fn h2_to_io_error(e: h2::Error) -> io::Error { +fn h2_to_io_error(e: h2::Error) -> std::io::Error { if e.is_io() { e.into_io().unwrap() } else { - io::Error::new(io::ErrorKind::Other, e) + std::io::Error::new(std::io::ErrorKind::Other, e) } } @@ -408,7 +408,7 @@ where unsafe { self.as_inner_unchecked().poll_reset(cx) } } - fn write(&mut self, buf: &[u8], end_of_stream: bool) -> Result<(), io::Error> { + fn write(&mut self, buf: &[u8], end_of_stream: bool) -> Result<(), std::io::Error> { let send_buf = SendBuf::Cursor(Cursor::new(buf.into())); unsafe { self.as_inner_unchecked() diff --git a/src/proto/h2/server.rs b/src/proto/h2/server.rs index bf458f428c..0913f314c9 100644 --- a/src/proto/h2/server.rs +++ b/src/proto/h2/server.rs @@ -3,12 +3,12 @@ use std::marker::Unpin; use std::time::Duration; +use crate::rt::{Read, Write}; use bytes::Bytes; use h2::server::{Connection, Handshake, SendResponse}; use h2::{Reason, RecvStream}; use http::{Method, Request}; use pin_project_lite::pin_project; -use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{debug, trace, warn}; use super::{ping, PipeToSendStream, SendBuf}; @@ -89,7 +89,7 @@ where { Handshaking { ping_config: ping::Config, - hs: Handshake>, + hs: Handshake, SendBuf>, }, Serving(Serving), Closed, @@ -100,13 +100,13 @@ where B: Body, { ping: Option<(ping::Recorder, ping::Ponger)>, - conn: Connection>, + conn: Connection, SendBuf>, closing: Option, } impl Server where - T: AsyncRead + AsyncWrite + Unpin, + T: Read + Write + Unpin, S: HttpService, S::Error: Into>, B: Body + 'static, @@ -132,7 +132,7 @@ where if config.enable_connect_protocol { builder.enable_connect_protocol(); } - let handshake = builder.handshake(io); + let handshake = builder.handshake(crate::common::io::compat(io)); let bdp = if config.adaptive_window { Some(config.initial_stream_window_size) @@ -182,7 +182,7 @@ where impl Future for Server where - T: AsyncRead + AsyncWrite + Unpin, + T: Read + Write + Unpin, S: HttpService, S::Error: Into>, B: Body + 'static, @@ -228,7 +228,7 @@ where impl Serving where - T: AsyncRead + AsyncWrite + Unpin, + T: Read + Write + Unpin, B: Body + 'static, { fn poll_server( diff --git a/src/rt/bounds.rs b/src/rt/bounds.rs index 6368339796..36f3683ead 100644 --- a/src/rt/bounds.rs +++ b/src/rt/bounds.rs @@ -13,8 +13,8 @@ pub use self::h2_client::ExecutorClient; #[cfg_attr(docsrs, doc(cfg(all(feature = "server", feature = "http2"))))] mod h2_client { use std::{error::Error, future::Future}; - use tokio::io::{AsyncRead, AsyncWrite}; + use crate::rt::{Read, Write}; use crate::{proto::h2::client::H2ClientFuture, rt::Executor}; /// An executor to spawn http2 futures for the client. @@ -29,7 +29,7 @@ mod h2_client { where B: http_body::Body, B::Error: Into>, - T: AsyncRead + AsyncWrite + Unpin, + T: Read + Write + Unpin, { #[doc(hidden)] fn execute_h2_future(&mut self, future: H2ClientFuture); @@ -41,7 +41,7 @@ mod h2_client { B: http_body::Body + 'static, B::Error: Into>, H2ClientFuture: Future, - T: AsyncRead + AsyncWrite + Unpin, + T: Read + Write + Unpin, { fn execute_h2_future(&mut self, future: H2ClientFuture) { self.execute(future) @@ -54,7 +54,7 @@ mod h2_client { B: http_body::Body + 'static, B::Error: Into>, H2ClientFuture: Future, - T: AsyncRead + AsyncWrite + Unpin, + T: Read + Write + Unpin, { } diff --git a/src/rt/io.rs b/src/rt/io.rs new file mode 100644 index 0000000000..c39e1e098d --- /dev/null +++ b/src/rt/io.rs @@ -0,0 +1,334 @@ +use std::fmt; +use std::mem::MaybeUninit; +use std::pin::Pin; +use std::task::{Context, Poll}; + +// New IO traits? What?! Why, are you bonkers? +// +// I mean, yes, probably. But, here's the goals: +// +// 1. Supports poll-based IO operations. +// 2. Opt-in vectored IO. +// 3. Can use an optional buffer pool. +// 4. Able to add completion-based (uring) IO eventually. +// +// Frankly, the last point is the entire reason we're doing this. We want to +// have forwards-compatibility with an eventually stable io-uring runtime. We +// don't need that to work right away. But it must be possible to add in here +// without breaking hyper 1.0. +// +// While in here, if there's small tweaks to poll_read or poll_write that would +// allow even the "slow" path to be faster, such as if someone didn't remember +// to forward along an `is_completion` call. + +/// Reads bytes from a source. +/// +/// This trait is similar to `std::io::Read`, but supports asynchronous reads. +pub trait Read { + /// Attempts to read bytes into the `buf`. + /// + /// On success, returns `Poll::Ready(Ok(()))` and places data in the + /// unfilled portion of `buf`. If no data was read (`buf.remaining()` is + /// unchanged), it implies that EOF has been reached. + /// + /// If no data is available for reading, the method returns `Poll::Pending` + /// and arranges for the current task (via `cx.waker()`) to receive a + /// notification when the object becomes readable or is closed. + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: ReadBufCursor<'_>, + ) -> Poll>; +} + +/// Write bytes asynchronously. +/// +/// This trait is similar to `std::io::Write`, but for asynchronous writes. +pub trait Write { + /// Attempt to write bytes from `buf` into the destination. + /// + /// On success, returns `Poll::Ready(Ok(num_bytes_written)))`. If + /// successful, it must be guaranteed that `n <= buf.len()`. A return value + /// of `0` means that the underlying object is no longer able to accept + /// bytes, or that the provided buffer is empty. + /// + /// If the object is not ready for writing, the method returns + /// `Poll::Pending` and arranges for the current task (via `cx.waker()`) to + /// receive a notification when the object becomes writable or is closed. + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll>; + + /// Attempts to flush the object. + /// + /// On success, returns `Poll::Ready(Ok(()))`. + /// + /// If flushing cannot immediately complete, this method returns + /// `Poll::Pending` and arranges for the current task (via `cx.waker()`) to + /// receive a notification when the object can make progress. + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; + + /// Attempts to shut down this writer. + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>; + + /// Returns whether this writer has an efficient `poll_write_vectored` + /// implementation. + /// + /// The default implementation returns `false`. + fn is_write_vectored(&self) -> bool { + false + } + + /// Like `poll_write`, except that it writes from a slice of buffers. + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + let buf = bufs + .iter() + .find(|b| !b.is_empty()) + .map_or(&[][..], |b| &**b); + self.poll_write(cx, buf) + } +} + +/// A wrapper around a byte buffer that is incrementally filled and initialized. +/// +/// This type is a sort of "double cursor". It tracks three regions in the +/// buffer: a region at the beginning of the buffer that has been logically +/// filled with data, a region that has been initialized at some point but not +/// yet logically filled, and a region at the end that may be uninitialized. +/// The filled region is guaranteed to be a subset of the initialized region. +/// +/// In summary, the contents of the buffer can be visualized as: +/// +/// ```not_rust +/// [ capacity ] +/// [ filled | unfilled ] +/// [ initialized | uninitialized ] +/// ``` +/// +/// It is undefined behavior to de-initialize any bytes from the uninitialized +/// region, since it is merely unknown whether this region is uninitialized or +/// not, and if part of it turns out to be initialized, it must stay initialized. +pub struct ReadBuf<'a> { + raw: &'a mut [MaybeUninit], + filled: usize, + init: usize, +} + +/// The cursor part of a [`ReadBuf`]. +/// +/// This is created by calling `ReadBuf::unfilled()`. +#[derive(Debug)] +pub struct ReadBufCursor<'a> { + buf: &'a mut ReadBuf<'a>, +} + +impl<'data> ReadBuf<'data> { + #[inline] + #[cfg(test)] + pub(crate) fn new(raw: &'data mut [u8]) -> Self { + let len = raw.len(); + Self { + // SAFETY: We never de-init the bytes ourselves. + raw: unsafe { &mut *(raw as *mut [u8] as *mut [MaybeUninit]) }, + filled: 0, + init: len, + } + } + + /// Create a new `ReadBuf` with a slice of uninitialized bytes. + #[inline] + pub fn uninit(raw: &'data mut [MaybeUninit]) -> Self { + Self { + raw, + filled: 0, + init: 0, + } + } + + /// Get a slice of the buffer that has been filled in with bytes. + #[inline] + pub fn filled(&self) -> &[u8] { + // SAFETY: We only slice the filled part of the buffer, which is always valid + unsafe { &*(&self.raw[0..self.filled] as *const [MaybeUninit] as *const [u8]) } + } + + /// Get a cursor to the unfilled portion of the buffer. + #[inline] + pub fn unfilled<'cursor>(&'cursor mut self) -> ReadBufCursor<'cursor> { + ReadBufCursor { + // SAFETY: self.buf is never re-assigned, so its safe to narrow + // the lifetime. + buf: unsafe { + std::mem::transmute::<&'cursor mut ReadBuf<'data>, &'cursor mut ReadBuf<'cursor>>( + self, + ) + }, + } + } + + #[inline] + pub(crate) unsafe fn set_init(&mut self, n: usize) { + self.init = self.init.max(n); + } + + #[inline] + pub(crate) unsafe fn set_filled(&mut self, n: usize) { + self.filled = self.filled.max(n); + } + + #[inline] + pub(crate) fn len(&self) -> usize { + self.filled + } + + #[inline] + pub(crate) fn init_len(&self) -> usize { + self.init + } + + #[inline] + fn remaining(&self) -> usize { + self.capacity() - self.filled + } + + #[inline] + fn capacity(&self) -> usize { + self.raw.len() + } +} + +impl<'data> fmt::Debug for ReadBuf<'data> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ReadBuf") + .field("filled", &self.filled) + .field("init", &self.init) + .field("capacity", &self.capacity()) + .finish() + } +} + +impl<'data> ReadBufCursor<'data> { + /// Access the unfilled part of the buffer. + /// + /// # Safety + /// + /// The caller must not uninitialize any bytes that may have been + /// initialized before. + #[inline] + pub unsafe fn as_mut(&mut self) -> &mut [MaybeUninit] { + &mut self.buf.raw[self.buf.filled..] + } + + /// Advance the `filled` cursor by `n` bytes. + /// + /// # Safety + /// + /// The caller must take care that `n` more bytes have been initialized. + #[inline] + pub unsafe fn advance(&mut self, n: usize) { + self.buf.filled = self.buf.filled.checked_add(n).expect("overflow"); + self.buf.init = self.buf.filled.max(self.buf.init); + } + + #[inline] + pub(crate) fn remaining(&self) -> usize { + self.buf.remaining() + } + + #[inline] + pub(crate) fn put_slice(&mut self, buf: &[u8]) { + assert!( + self.buf.remaining() >= buf.len(), + "buf.len() must fit in remaining()" + ); + + let amt = buf.len(); + // Cannot overflow, asserted above + let end = self.buf.filled + amt; + + // Safety: the length is asserted above + unsafe { + self.buf.raw[self.buf.filled..end] + .as_mut_ptr() + .cast::() + .copy_from_nonoverlapping(buf.as_ptr(), amt); + } + + if self.buf.init < end { + self.buf.init = end; + } + self.buf.filled = end; + } +} + +macro_rules! deref_async_read { + () => { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: ReadBufCursor<'_>, + ) -> Poll> { + Pin::new(&mut **self).poll_read(cx, buf) + } + }; +} + +impl Read for Box { + deref_async_read!(); +} + +impl Read for &mut T { + deref_async_read!(); +} + +macro_rules! deref_async_write { + () => { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut **self).poll_write(cx, buf) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut **self).poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + (**self).is_write_vectored() + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut **self).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut **self).poll_shutdown(cx) + } + }; +} + +impl Write for Box { + deref_async_write!(); +} + +impl Write for &mut T { + deref_async_write!(); +} diff --git a/src/rt/mod.rs b/src/rt/mod.rs index 82854a546c..de67c3fc89 100644 --- a/src/rt/mod.rs +++ b/src/rt/mod.rs @@ -1,14 +1,18 @@ //! Runtime components //! -//! By default, hyper includes the [tokio](https://tokio.rs) runtime. +//! The traits and types within this module are used to allow plugging in +//! runtime types. These include: //! -//! If the `runtime` feature is disabled, the types in this module can be used -//! to plug in other runtimes. +//! - Executors +//! - Timers +//! - IO transports pub mod bounds; +mod io; mod timer; -pub use timer::{Sleep, Timer}; +pub use self::io::{Read, ReadBuf, ReadBufCursor, Write}; +pub use self::timer::{Sleep, Timer}; /// An executor of futures. /// diff --git a/src/server/conn/http1.rs b/src/server/conn/http1.rs index 530082e966..09770cd3cb 100644 --- a/src/server/conn/http1.rs +++ b/src/server/conn/http1.rs @@ -5,8 +5,8 @@ use std::fmt; use std::sync::Arc; use std::time::Duration; +use crate::rt::{Read, Write}; use bytes::Bytes; -use tokio::io::{AsyncRead, AsyncWrite}; use crate::body::{Body, Incoming as IncomingBody}; use crate::common::{task, Future, Pin, Poll, Unpin}; @@ -85,7 +85,7 @@ impl Connection where S: HttpService, S::Error: Into>, - I: AsyncRead + AsyncWrite + Unpin, + I: Read + Write + Unpin, B: Body + 'static, B::Error: Into>, { @@ -172,7 +172,7 @@ impl Future for Connection where S: HttpService, S::Error: Into>, - I: AsyncRead + AsyncWrite + Unpin + 'static, + I: Read + Write + Unpin + 'static, B: Body + 'static, B::Error: Into>, { @@ -333,10 +333,10 @@ impl Builder { /// # use hyper::{body::Incoming, Request, Response}; /// # use hyper::service::Service; /// # use hyper::server::conn::http1::Builder; - /// # use tokio::io::{AsyncRead, AsyncWrite}; + /// # use hyper::rt::{Read, Write}; /// # async fn run(some_io: I, some_service: S) /// # where - /// # I: AsyncRead + AsyncWrite + Unpin + Send + 'static, + /// # I: Read + Write + Unpin + Send + 'static, /// # S: Service, Response=hyper::Response> + Send + 'static, /// # S::Error: Into>, /// # S::Future: Send, @@ -356,7 +356,7 @@ impl Builder { S::Error: Into>, S::ResBody: 'static, ::Error: Into>, - I: AsyncRead + AsyncWrite + Unpin, + I: Read + Write + Unpin, { let mut conn = proto::Conn::new(io); conn.set_timer(self.timer.clone()); @@ -413,7 +413,7 @@ mod upgrades { where S: HttpService, S::Error: Into>, - I: AsyncRead + AsyncWrite + Unpin, + I: Read + Write + Unpin, B: Body + 'static, B::Error: Into>, { @@ -430,7 +430,7 @@ mod upgrades { where S: HttpService, S::Error: Into>, - I: AsyncRead + AsyncWrite + Unpin + Send + 'static, + I: Read + Write + Unpin + Send + 'static, B: Body + 'static, B::Error: Into>, { diff --git a/src/server/conn/http2.rs b/src/server/conn/http2.rs index e1345f3b6b..22a7cdcff1 100644 --- a/src/server/conn/http2.rs +++ b/src/server/conn/http2.rs @@ -5,8 +5,8 @@ use std::fmt; use std::sync::Arc; use std::time::Duration; +use crate::rt::{Read, Write}; use pin_project_lite::pin_project; -use tokio::io::{AsyncRead, AsyncWrite}; use crate::body::{Body, Incoming as IncomingBody}; use crate::common::{task, Future, Pin, Poll, Unpin}; @@ -51,7 +51,7 @@ impl Connection where S: HttpService, S::Error: Into>, - I: AsyncRead + AsyncWrite + Unpin, + I: Read + Write + Unpin, B: Body + 'static, B::Error: Into>, E: Http2ConnExec, @@ -75,7 +75,7 @@ impl Future for Connection where S: HttpService, S::Error: Into>, - I: AsyncRead + AsyncWrite + Unpin + 'static, + I: Read + Write + Unpin + 'static, B: Body + 'static, B::Error: Into>, E: Http2ConnExec, @@ -255,7 +255,7 @@ impl Builder { S::Error: Into>, Bd: Body + 'static, Bd::Error: Into>, - I: AsyncRead + AsyncWrite + Unpin, + I: Read + Write + Unpin, E: Http2ConnExec, { let proto = proto::h2::Server::new( diff --git a/src/server/conn/mod.rs b/src/server/conn/mod.rs index f2abae22aa..b7dea1b8c6 100644 --- a/src/server/conn/mod.rs +++ b/src/server/conn/mod.rs @@ -7,43 +7,6 @@ //! //! This module is split by HTTP version. Both work similarly, but do have //! specific options on each builder. -//! -//! ## Example -//! -//! A simple example that prepares an HTTP/1 connection over a Tokio TCP stream. -//! -//! ```no_run -//! # #[cfg(feature = "http1")] -//! # mod rt { -//! use http::{Request, Response, StatusCode}; -//! use http_body_util::Full; -//! use hyper::{server::conn::http1, service::service_fn, body, body::Bytes}; -//! use std::{net::SocketAddr, convert::Infallible}; -//! use tokio::net::TcpListener; -//! -//! #[tokio::main] -//! async fn main() -> Result<(), Box> { -//! let addr: SocketAddr = ([127, 0, 0, 1], 8080).into(); -//! -//! let mut tcp_listener = TcpListener::bind(addr).await?; -//! loop { -//! let (tcp_stream, _) = tcp_listener.accept().await?; -//! tokio::task::spawn(async move { -//! if let Err(http_err) = http1::Builder::new() -//! .keep_alive(true) -//! .serve_connection(tcp_stream, service_fn(hello)) -//! .await { -//! eprintln!("Error while serving HTTP connection: {}", http_err); -//! } -//! }); -//! } -//! } -//! -//! async fn hello(_req: Request) -> Result>, Infallible> { -//! Ok(Response::new(Full::new(Bytes::from("Hello World!")))) -//! } -//! # } -//! ``` #[cfg(feature = "http1")] pub mod http1; diff --git a/src/upgrade.rs b/src/upgrade.rs index 1c7b5b01cd..231578f913 100644 --- a/src/upgrade.rs +++ b/src/upgrade.rs @@ -45,8 +45,8 @@ use std::fmt; use std::io; use std::marker::Unpin; +use crate::rt::{Read, ReadBufCursor, Write}; use bytes::Bytes; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::sync::oneshot; #[cfg(any(feature = "http1", feature = "http2"))] use tracing::trace; @@ -122,7 +122,7 @@ impl Upgraded { #[cfg(any(feature = "http1", feature = "http2", test))] pub(super) fn new(io: T, read_buf: Bytes) -> Self where - T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + T: Read + Write + Unpin + Send + 'static, { Upgraded { io: Rewind::new_buffered(Box::new(io), read_buf), @@ -133,7 +133,7 @@ impl Upgraded { /// /// On success, returns the downcasted parts. On error, returns the /// `Upgraded` back. - pub fn downcast(self) -> Result, Self> { + pub fn downcast(self) -> Result, Self> { let (io, buf) = self.io.into_inner(); match io.__hyper_downcast() { Ok(t) => Ok(Parts { @@ -148,17 +148,17 @@ impl Upgraded { } } -impl AsyncRead for Upgraded { +impl Read for Upgraded { fn poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, - buf: &mut ReadBuf<'_>, + buf: ReadBufCursor<'_>, ) -> Poll> { Pin::new(&mut self.io).poll_read(cx, buf) } } -impl AsyncWrite for Upgraded { +impl Write for Upgraded { fn poll_write( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, @@ -265,13 +265,13 @@ impl StdError for UpgradeExpected {} // ===== impl Io ===== -pub(super) trait Io: AsyncRead + AsyncWrite + Unpin + 'static { +pub(super) trait Io: Read + Write + Unpin + 'static { fn __hyper_type_id(&self) -> TypeId { TypeId::of::() } } -impl Io for T {} +impl Io for T {} impl dyn Io + Send { fn __hyper_is(&self) -> bool { @@ -340,7 +340,9 @@ mod tests { fn upgraded_downcast() { let upgraded = Upgraded::new(Mock, Bytes::new()); - let upgraded = upgraded.downcast::>>().unwrap_err(); + let upgraded = upgraded + .downcast::>>>() + .unwrap_err(); upgraded.downcast::().unwrap(); } @@ -348,17 +350,17 @@ mod tests { // TODO: replace with tokio_test::io when it can test write_buf struct Mock; - impl AsyncRead for Mock { + impl Read for Mock { fn poll_read( self: Pin<&mut Self>, _cx: &mut task::Context<'_>, - _buf: &mut ReadBuf<'_>, + _buf: ReadBufCursor<'_>, ) -> Poll> { unreachable!("Mock::poll_read") } } - impl AsyncWrite for Mock { + impl Write for Mock { fn poll_write( self: Pin<&mut Self>, _: &mut task::Context<'_>, diff --git a/tests/client.rs b/tests/client.rs index 842282c5bb..ef80596c01 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -22,6 +22,7 @@ use hyper::{Method, Request, StatusCode, Uri, Version}; use bytes::Bytes; use futures_channel::oneshot; use futures_util::future::{self, FutureExt, TryFuture, TryFutureExt}; +use support::TokioIo; use tokio::net::TcpStream; mod support; @@ -36,8 +37,8 @@ where b.collect().await.map(|c| c.to_bytes()) } -fn tcp_connect(addr: &SocketAddr) -> impl Future> { - TcpStream::connect(*addr) +async fn tcp_connect(addr: &SocketAddr) -> std::io::Result> { + TcpStream::connect(*addr).await.map(TokioIo::new) } struct HttpInfo { @@ -312,7 +313,7 @@ macro_rules! test { req.headers_mut().append("Host", HeaderValue::from_str(&host).unwrap()); } - let (mut sender, conn) = builder.handshake(stream).await?; + let (mut sender, conn) = builder.handshake(TokioIo::new(stream)).await?; tokio::task::spawn(async move { if let Err(err) = conn.await { @@ -1339,7 +1340,7 @@ mod conn { use futures_util::future::{self, poll_fn, FutureExt, TryFutureExt}; use http_body_util::{BodyExt, Empty, StreamBody}; use hyper::rt::Timer; - use tokio::io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _, ReadBuf}; + use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _}; use tokio::net::{TcpListener as TkTcpListener, TcpStream}; use hyper::body::{Body, Frame}; @@ -1349,7 +1350,7 @@ mod conn { use super::{concat, s, support, tcp_connect, FutureHyperExt}; - use support::{TokioExecutor, TokioTimer}; + use support::{TokioExecutor, TokioIo, TokioTimer}; fn setup_logger() { let _ = pretty_env_logger::try_init(); @@ -1773,7 +1774,7 @@ mod conn { } let parts = conn.into_parts(); - let mut io = parts.io; + let io = parts.io; let buf = parts.read_buf; assert_eq!(buf, b"foobar=ready"[..]); @@ -1785,6 +1786,7 @@ mod conn { })) .unwrap_err(); + let mut io = io.tcp.inner(); let mut vec = vec![]; rt.block_on(io.write_all(b"foo=bar")).unwrap(); rt.block_on(io.read_to_end(&mut vec)).unwrap(); @@ -1861,7 +1863,7 @@ mod conn { } let parts = conn.into_parts(); - let mut io = parts.io; + let io = parts.io; let buf = parts.read_buf; assert_eq!(buf, b"foobar=ready"[..]); @@ -1874,6 +1876,7 @@ mod conn { })) .unwrap_err(); + let mut io = io.tcp.inner(); let mut vec = vec![]; rt.block_on(io.write_all(b"foo=bar")).unwrap(); rt.block_on(io.read_to_end(&mut vec)).unwrap(); @@ -1895,6 +1898,7 @@ mod conn { tokio::select! { res = listener.accept() => { let (stream, _) = res.unwrap(); + let stream = TokioIo::new(stream); let service = service_fn(|_:Request| future::ok::<_, hyper::Error>(Response::new(Empty::::new()))); @@ -2077,7 +2081,7 @@ mod conn { // Spawn an HTTP2 server that reads the whole body and responds tokio::spawn(async move { - let sock = listener.accept().await.unwrap().0; + let sock = TokioIo::new(listener.accept().await.unwrap().0); hyper::server::conn::http2::Builder::new(TokioExecutor) .timer(TokioTimer) .serve_connection( @@ -2166,7 +2170,7 @@ mod conn { let res = client.send_request(req).await.expect("send_request"); assert_eq!(res.status(), StatusCode::OK); - let mut upgraded = hyper::upgrade::on(res).await.unwrap(); + let mut upgraded = TokioIo::new(hyper::upgrade::on(res).await.unwrap()); let mut vec = vec![]; upgraded.read_to_end(&mut vec).await.unwrap(); @@ -2264,7 +2268,7 @@ mod conn { ); } - async fn drain_til_eof(mut sock: T) -> io::Result<()> { + async fn drain_til_eof(mut sock: T) -> io::Result<()> { let mut buf = [0u8; 1024]; loop { let n = sock.read(&mut buf).await?; @@ -2276,11 +2280,11 @@ mod conn { } struct DebugStream { - tcp: TcpStream, + tcp: TokioIo, shutdown_called: bool, } - impl AsyncWrite for DebugStream { + impl hyper::rt::Write for DebugStream { fn poll_shutdown( mut self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -2305,11 +2309,11 @@ mod conn { } } - impl AsyncRead for DebugStream { + impl hyper::rt::Read for DebugStream { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, + buf: hyper::rt::ReadBufCursor<'_>, ) -> Poll> { Pin::new(&mut self.tcp).poll_read(cx, buf) } diff --git a/tests/server.rs b/tests/server.rs index b412de038d..98ded22d73 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -22,8 +22,8 @@ use h2::{RecvStream, SendStream}; use http::header::{HeaderName, HeaderValue}; use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody}; use hyper::rt::Timer; -use support::{TokioExecutor, TokioTimer}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use hyper::rt::{Read as AsyncRead, Write as AsyncWrite}; +use support::{TokioExecutor, TokioIo, TokioTimer}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener as TkTcpListener, TcpListener, TcpStream as TkTcpStream}; @@ -975,6 +975,7 @@ async fn expect_continue_waits_for_body_poll() { }); let (socket, _) = listener.accept().await.expect("accept"); + let socket = TokioIo::new(socket); http1::Builder::new() .serve_connection( @@ -1154,6 +1155,7 @@ async fn disable_keep_alive_mid_request() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); let srv = http1::Builder::new().serve_connection(socket, HelloWorld); future::try_select(srv, rx1) .then(|r| match r { @@ -1201,7 +1203,7 @@ async fn disable_keep_alive_post_request() { let dropped2 = dropped.clone(); let (socket, _) = listener.accept().await.unwrap(); let transport = DebugStream { - stream: socket, + stream: TokioIo::new(socket), _debug: dropped2, }; let server = http1::Builder::new().serve_connection(transport, HelloWorld); @@ -1229,6 +1231,7 @@ async fn empty_parse_eof_does_not_return_error() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .serve_connection(socket, HelloWorld) .await @@ -1245,6 +1248,7 @@ async fn nonempty_parse_eof_returns_error() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .serve_connection(socket, HelloWorld) .await @@ -1268,6 +1272,7 @@ async fn http1_allow_half_close() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .half_close(true) .serve_connection( @@ -1295,6 +1300,7 @@ async fn disconnect_after_reading_request_before_responding() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .half_close(false) .serve_connection( @@ -1326,6 +1332,7 @@ async fn returning_1xx_response_is_error() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .serve_connection( socket, @@ -1390,6 +1397,7 @@ async fn header_read_timeout_slow_writes() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); let conn = http1::Builder::new() .timer(TokioTimer) .header_read_timeout(Duration::from_secs(5)) @@ -1465,6 +1473,7 @@ async fn header_read_timeout_slow_writes_multiple_requests() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); let conn = http1::Builder::new() .timer(TokioTimer) .header_read_timeout(Duration::from_secs(5)) @@ -1511,6 +1520,7 @@ async fn upgrades() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); let conn = http1::Builder::new().serve_connection( socket, service_fn(|_| { @@ -1529,7 +1539,7 @@ async fn upgrades() { // wait so that we don't write until other side saw 101 response rx.await.unwrap(); - let mut io = parts.io; + let mut io = parts.io.inner(); io.write_all(b"foo=bar").await.unwrap(); let mut vec = vec![]; io.read_to_end(&mut vec).await.unwrap(); @@ -1564,6 +1574,7 @@ async fn http_connect() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); let conn = http1::Builder::new().serve_connection( socket, service_fn(|_| { @@ -1581,7 +1592,7 @@ async fn http_connect() { // wait so that we don't write until other side saw 101 response rx.await.unwrap(); - let mut io = parts.io; + let mut io = parts.io.inner(); io.write_all(b"foo=bar").await.unwrap(); let mut vec = vec![]; io.read_to_end(&mut vec).await.unwrap(); @@ -1634,6 +1645,7 @@ async fn upgrades_new() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .serve_connection(socket, svc) .with_upgrades() @@ -1646,10 +1658,10 @@ async fn upgrades_new() { read_101_rx.await.unwrap(); let upgraded = on_upgrade.await.expect("on_upgrade"); - let parts = upgraded.downcast::().unwrap(); + let parts = upgraded.downcast::>().unwrap(); assert_eq!(parts.read_buf, "eagerly optimistic"); - let mut io = parts.io; + let mut io = parts.io.inner(); io.write_all(b"foo=bar").await.unwrap(); let mut vec = vec![]; io.read_to_end(&mut vec).await.unwrap(); @@ -1668,6 +1680,7 @@ async fn upgrades_ignored() { loop { let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); tokio::task::spawn(async move { http1::Builder::new() .serve_connection(socket, svc) @@ -1738,6 +1751,7 @@ async fn http_connect_new() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .serve_connection(socket, svc) .with_upgrades() @@ -1750,10 +1764,10 @@ async fn http_connect_new() { read_200_rx.await.unwrap(); let upgraded = on_upgrade.await.expect("on_upgrade"); - let parts = upgraded.downcast::().unwrap(); + let parts = upgraded.downcast::>().unwrap(); assert_eq!(parts.read_buf, "eagerly optimistic"); - let mut io = parts.io; + let mut io = parts.io.inner(); io.write_all(b"foo=bar").await.unwrap(); let mut vec = vec![]; io.read_to_end(&mut vec).await.unwrap(); @@ -1799,7 +1813,7 @@ async fn h2_connect() { let on_upgrade = hyper::upgrade::on(req); tokio::spawn(async move { - let mut upgraded = on_upgrade.await.expect("on_upgrade"); + let mut upgraded = TokioIo::new(on_upgrade.await.expect("on_upgrade")); upgraded.write_all(b"Bread?").await.unwrap(); let mut vec = vec![]; @@ -1818,6 +1832,7 @@ async fn h2_connect() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http2::Builder::new(TokioExecutor) .serve_connection(socket, svc) //.with_upgrades() @@ -1891,7 +1906,7 @@ async fn h2_connect_multiplex() { assert!(upgrade_res.expect_err("upgrade cancelled").is_canceled()); return; } - let mut upgraded = upgrade_res.expect("upgrade successful"); + let mut upgraded = TokioIo::new(upgrade_res.expect("upgrade successful")); upgraded.write_all(b"Bread?").await.unwrap(); @@ -1927,6 +1942,7 @@ async fn h2_connect_multiplex() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http2::Builder::new(TokioExecutor) .serve_connection(socket, svc) //.with_upgrades() @@ -1978,7 +1994,7 @@ async fn h2_connect_large_body() { let on_upgrade = hyper::upgrade::on(req); tokio::spawn(async move { - let mut upgraded = on_upgrade.await.expect("on_upgrade"); + let mut upgraded = TokioIo::new(on_upgrade.await.expect("on_upgrade")); upgraded.write_all(b"Bread?").await.unwrap(); let mut vec = vec![]; @@ -1999,6 +2015,7 @@ async fn h2_connect_large_body() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http2::Builder::new(TokioExecutor) .serve_connection(socket, svc) //.with_upgrades() @@ -2049,7 +2066,7 @@ async fn h2_connect_empty_frames() { let on_upgrade = hyper::upgrade::on(req); tokio::spawn(async move { - let mut upgraded = on_upgrade.await.expect("on_upgrade"); + let mut upgraded = TokioIo::new(on_upgrade.await.expect("on_upgrade")); upgraded.write_all(b"Bread?").await.unwrap(); let mut vec = vec![]; @@ -2068,6 +2085,7 @@ async fn h2_connect_empty_frames() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http2::Builder::new(TokioExecutor) .serve_connection(socket, svc) //.with_upgrades() @@ -2090,6 +2108,7 @@ async fn parse_errors_send_4xx_response() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .serve_connection(socket, HelloWorld) .await @@ -2112,6 +2131,7 @@ async fn illegal_request_length_returns_400_response() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .serve_connection(socket, HelloWorld) .await @@ -2152,6 +2172,7 @@ async fn max_buf_size() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .max_buf_size(MAX) .serve_connection(socket, HelloWorld) @@ -2166,6 +2187,7 @@ async fn graceful_shutdown_before_first_request_no_block() { tokio::spawn(async move { let socket = listener.accept().await.unwrap().0; + let socket = TokioIo::new(socket); let future = http1::Builder::new().serve_connection(socket, HelloWorld); pin!(future); @@ -2407,6 +2429,7 @@ async fn http2_keep_alive_detects_unresponsive_client() { }); let (socket, _) = listener.accept().await.expect("accept"); + let socket = TokioIo::new(socket); let err = http2::Builder::new(TokioExecutor) .timer(TokioTimer) @@ -2425,6 +2448,7 @@ async fn http2_keep_alive_with_responsive_client() { tokio::spawn(async move { let (socket, _) = listener.accept().await.expect("accept"); + let socket = TokioIo::new(socket); http2::Builder::new(TokioExecutor) .timer(TokioTimer) @@ -2435,7 +2459,7 @@ async fn http2_keep_alive_with_responsive_client() { .expect("serve_connection"); }); - let tcp = connect_async(addr).await; + let tcp = TokioIo::new(connect_async(addr).await); let (mut client, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor) .handshake(tcp) .await @@ -2488,6 +2512,7 @@ async fn http2_keep_alive_count_server_pings() { tokio::spawn(async move { let (socket, _) = listener.accept().await.expect("accept"); + let socket = TokioIo::new(socket); http2::Builder::new(TokioExecutor) .timer(TokioTimer) @@ -2871,6 +2896,7 @@ impl ServeOptions { tokio::select! { res = listener.accept() => { let (stream, _) = res.unwrap(); + let stream = TokioIo::new(stream); tokio::task::spawn(async move { let msg_tx = msg_tx.clone(); @@ -2922,7 +2948,7 @@ fn has_header(msg: &str, name: &str) -> bool { msg[..n].contains(name) } -fn tcp_bind(addr: &SocketAddr) -> ::tokio::io::Result { +fn tcp_bind(addr: &SocketAddr) -> std::io::Result { let std_listener = StdTcpListener::bind(addr).unwrap(); std_listener.set_nonblocking(true).unwrap(); TcpListener::from_std(std_listener) @@ -3001,7 +3027,7 @@ impl AsyncRead for DebugStream { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, + buf: hyper::rt::ReadBufCursor<'_>, ) -> Poll> { Pin::new(&mut self.stream).poll_read(cx, buf) } @@ -3058,9 +3084,11 @@ impl TestClient { let host = req.uri().host().expect("uri has no host"); let port = req.uri().port_u16().expect("uri has no port"); - let stream = TkTcpStream::connect(format!("{}:{}", host, port)) - .await - .unwrap(); + let stream = TokioIo::new( + TkTcpStream::connect(format!("{}:{}", host, port)) + .await + .unwrap(), + ); if self.http2_only { let (mut sender, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor) diff --git a/tests/support/mod.rs b/tests/support/mod.rs index e7e1e8c6bd..c46eff89ea 100644 --- a/tests/support/mod.rs +++ b/tests/support/mod.rs @@ -21,7 +21,7 @@ pub use hyper::{HeaderMap, StatusCode}; pub use std::net::SocketAddr; mod tokiort; -pub use tokiort::{TokioExecutor, TokioTimer}; +pub use tokiort::{TokioExecutor, TokioIo, TokioTimer}; #[allow(unused_macros)] macro_rules! t { @@ -357,6 +357,7 @@ async fn async_test(cfg: __TestConfig) { loop { let (stream, _) = listener.accept().await.expect("server error"); + let io = TokioIo::new(stream); // Move a clone into the service_fn let serve_handles = serve_handles.clone(); @@ -386,12 +387,12 @@ async fn async_test(cfg: __TestConfig) { tokio::task::spawn(async move { if http2_only { server::conn::http2::Builder::new(TokioExecutor) - .serve_connection(stream, service) + .serve_connection(io, service) .await .expect("server error"); } else { server::conn::http1::Builder::new() - .serve_connection(stream, service) + .serve_connection(io, service) .await .expect("server error"); } @@ -425,10 +426,11 @@ async fn async_test(cfg: __TestConfig) { async move { let stream = TcpStream::connect(addr).await.unwrap(); + let io = TokioIo::new(stream); let res = if http2_only { let (mut sender, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor) - .handshake(stream) + .handshake(io) .await .unwrap(); @@ -440,7 +442,7 @@ async fn async_test(cfg: __TestConfig) { sender.send_request(req).await.unwrap() } else { let (mut sender, conn) = hyper::client::conn::http1::Builder::new() - .handshake(stream) + .handshake(io) .await .unwrap(); @@ -508,6 +510,7 @@ async fn naive_proxy(cfg: ProxyConfig) -> (SocketAddr, impl Future) loop { let (stream, _) = listener.accept().await.unwrap(); + let io = TokioIo::new(stream); let service = service_fn(move |mut req| { async move { @@ -523,11 +526,12 @@ async fn naive_proxy(cfg: ProxyConfig) -> (SocketAddr, impl Future) let stream = TcpStream::connect(format!("{}:{}", uri, port)) .await .unwrap(); + let io = TokioIo::new(stream); let resp = if http2_only { let (mut sender, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor) - .handshake(stream) + .handshake(io) .await .unwrap(); @@ -540,7 +544,7 @@ async fn naive_proxy(cfg: ProxyConfig) -> (SocketAddr, impl Future) sender.send_request(req).await? } else { let builder = hyper::client::conn::http1::Builder::new(); - let (mut sender, conn) = builder.handshake(stream).await.unwrap(); + let (mut sender, conn) = builder.handshake(io).await.unwrap(); tokio::task::spawn(async move { if let Err(err) = conn.await { @@ -569,12 +573,12 @@ async fn naive_proxy(cfg: ProxyConfig) -> (SocketAddr, impl Future) if http2_only { server::conn::http2::Builder::new(TokioExecutor) - .serve_connection(stream, service) + .serve_connection(io, service) .await .unwrap(); } else { server::conn::http1::Builder::new() - .serve_connection(stream, service) + .serve_connection(io, service) .await .unwrap(); }