diff --git a/src/client/pool.rs b/src/client/pool.rs index b112979f64..5a1afee541 100644 --- a/src/client/pool.rs +++ b/src/client/pool.rs @@ -3,7 +3,7 @@ use std::collections::{HashMap, VecDeque}; use std::fmt; use std::io; use std::ops::{Deref, DerefMut, BitAndAssign}; -use std::rc::Rc; +use std::rc::{Rc, Weak}; use std::time::{Duration, Instant}; use futures::{Future, Async, Poll}; @@ -103,7 +103,7 @@ impl Pool { status: Rc::new(Cell::new(TimedKA::Busy)), }, key: key, - pool: self.clone(), + pool: Rc::downgrade(&self.inner), } } @@ -118,7 +118,7 @@ impl Pool { Pooled { entry: entry, key: key, - pool: self.clone(), + pool: Rc::downgrade(&self.inner), } } @@ -161,7 +161,7 @@ impl Clone for Pool { pub struct Pooled { entry: Entry, key: Rc, - pool: Pool, + pool: Weak>>, } impl Deref for Pooled { @@ -194,8 +194,16 @@ impl KeepAlive for Pooled { return; } self.entry.is_reused = true; - if self.pool.is_enabled() { - self.pool.put(self.key.clone(), self.entry.clone()); + if let Some(inner) = self.pool.upgrade() { + let mut pool = Pool { + inner: inner, + }; + if pool.is_enabled() { + pool.put(self.key.clone(), self.entry.clone()); + } + } else { + trace!("pool dropped, dropping pooled ({:?})", self.key); + self.entry.status.set(TimedKA::Disabled); } } diff --git a/src/proto/conn.rs b/src/proto/conn.rs index e6e09dd45c..3530bed552 100644 --- a/src/proto/conn.rs +++ b/src/proto/conn.rs @@ -235,6 +235,9 @@ where I: AsyncRead + AsyncWrite, // // When writing finishes, we need to wake the task up in case there // is more reading that can be done, to start a new message. + + + let wants_read = match self.state.reading { Reading::Body(..) | Reading::KeepAlive => return, @@ -242,13 +245,19 @@ where I: AsyncRead + AsyncWrite, Reading::Closed => false, }; - match self.state.writing { + let wants_write = match self.state.writing { Writing::Continue(..) | Writing::Body(..) | Writing::Ending(..) => return, - Writing::Init | - Writing::KeepAlive | - Writing::Closed => (), + Writing::Init => true, + Writing::KeepAlive => false, + Writing::Closed => false, + }; + + // if the client is at Reading::Init and Writing::Init, + // it's not actually looking for a read, but a write. + if wants_write && !T::should_read_first() { + return; } if !self.io.is_read_blocked() { @@ -704,9 +713,13 @@ impl State { fn idle(&mut self) { self.method = None; - self.reading = Reading::Init; - self.writing = Writing::Init; self.keep_alive.idle(); + if self.is_idle() { + self.reading = Reading::Init; + self.writing = Writing::Init; + } else { + self.close(); + } } fn is_idle(&self) -> bool { diff --git a/tests/client.rs b/tests/client.rs index 2120b9dc32..c3a7af50a1 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -543,57 +543,115 @@ fn client_pooled_socket_disconnected() { } */ -#[test] -fn drop_body_before_eof_closes_connection() { - // /~https://github.com/hyperium/hyper/issues/1353 +mod dispatch_impl { + use super::*; use std::io::{self, Read, Write}; use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; + use std::thread; use std::time::Duration; + + use futures::{self, Future}; + use futures::sync::oneshot; use tokio_core::reactor::{Timeout}; use tokio_core::net::TcpStream; use tokio_io::{AsyncRead, AsyncWrite}; + use hyper::client::HttpConnector; use hyper::server::Service; - use hyper::Uri; + use hyper::{Client, Uri}; + use hyper; - let _ = pretty_env_logger::init(); - let server = TcpListener::bind("127.0.0.1:0").unwrap(); - let addr = server.local_addr().unwrap(); - let mut core = Core::new().unwrap(); - let handle = core.handle(); - let closes = Arc::new(AtomicUsize::new(0)); - let client = Client::configure() - .connector(DebugConnector(HttpConnector::new(1, &core.handle()), closes.clone())) - .no_proto() - .build(&handle); - let (tx1, rx1) = oneshot::channel(); + #[test] + fn drop_body_before_eof_closes_connection() { + // /~https://github.com/hyperium/hyper/issues/1353 + let _ = pretty_env_logger::init(); - thread::spawn(move || { - let mut sock = server.accept().unwrap().0; - sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); - sock.set_write_timeout(Some(Duration::from_secs(5))).unwrap(); - let mut buf = [0; 4096]; - sock.read(&mut buf).expect("read 1"); - let body = vec![b'x'; 1024 * 128]; - write!(sock, "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n", body.len()).expect("write head"); - let _ = sock.write_all(&body); - let _ = tx1.send(()); - }); + let server = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = server.local_addr().unwrap(); + let mut core = Core::new().unwrap(); + let handle = core.handle(); + let closes = Arc::new(AtomicUsize::new(0)); + let client = Client::configure() + .connector(DebugConnector(HttpConnector::new(1, &core.handle()), closes.clone())) + .no_proto() + .build(&handle); + + let (tx1, rx1) = oneshot::channel(); + + thread::spawn(move || { + let mut sock = server.accept().unwrap().0; + sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); + sock.set_write_timeout(Some(Duration::from_secs(5))).unwrap(); + let mut buf = [0; 4096]; + sock.read(&mut buf).expect("read 1"); + let body = vec![b'x'; 1024 * 128]; + write!(sock, "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n", body.len()).expect("write head"); + let _ = sock.write_all(&body); + let _ = tx1.send(()); + }); - let uri = format!("http://{}/a", addr).parse().unwrap(); + let uri = format!("http://{}/a", addr).parse().unwrap(); - let res = client.get(uri).and_then(move |res| { - assert_eq!(res.status(), hyper::StatusCode::Ok); - Timeout::new(Duration::from_secs(1), &handle).unwrap() - .from_err() - }); - let rx = rx1.map_err(|_| hyper::Error::Io(io::Error::new(io::ErrorKind::Other, "thread panicked"))); - core.run(res.join(rx).map(|r| r.0)).unwrap(); + let res = client.get(uri).and_then(move |res| { + assert_eq!(res.status(), hyper::StatusCode::Ok); + Timeout::new(Duration::from_secs(1), &handle).unwrap() + .from_err() + }); + let rx = rx1.map_err(|_| hyper::Error::Io(io::Error::new(io::ErrorKind::Other, "thread panicked"))); + core.run(res.join(rx).map(|r| r.0)).unwrap(); + + assert_eq!(closes.load(Ordering::Relaxed), 1); + } + + #[test] + fn drop_client_closes_connection() { + // /~https://github.com/hyperium/hyper/issues/1353 + let _ = pretty_env_logger::init(); - assert_eq!(closes.load(Ordering::Relaxed), 1); + let server = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = server.local_addr().unwrap(); + let mut core = Core::new().unwrap(); + let handle = core.handle(); + let closes = Arc::new(AtomicUsize::new(0)); + + let (tx1, rx1) = oneshot::channel(); + + thread::spawn(move || { + let mut sock = server.accept().unwrap().0; + sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); + sock.set_write_timeout(Some(Duration::from_secs(5))).unwrap(); + let mut buf = [0; 4096]; + sock.read(&mut buf).expect("read 1"); + let body =[b'x'; 64]; + write!(sock, "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n", body.len()).expect("write head"); + let _ = sock.write_all(&body); + let _ = tx1.send(()); + }); + + let uri = format!("http://{}/a", addr).parse().unwrap(); + + let res = { + let client = Client::configure() + .connector(DebugConnector(HttpConnector::new(1, &handle), closes.clone())) + .no_proto() + .build(&handle); + client.get(uri).and_then(move |res| { + assert_eq!(res.status(), hyper::StatusCode::Ok); + res.body().concat2() + }).and_then(|_| { + Timeout::new(Duration::from_secs(1), &handle).unwrap() + .from_err() + }) + }; + // client is dropped + let rx = rx1.map_err(|_| hyper::Error::Io(io::Error::new(io::ErrorKind::Other, "thread panicked"))); + core.run(res.join(rx).map(|r| r.0)).unwrap(); + + assert_eq!(closes.load(Ordering::Relaxed), 1); + }