From d22deb6572c279e11773b6bcb862415c08f19c2e Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Tue, 23 Jan 2018 16:09:17 -0800 Subject: [PATCH] feat(server): add `Http::max_buf_size()` option The internal connection's read and write bufs will be restricted from growing bigger than the configured `max_buf_size`. Closes #1368 --- src/proto/conn.rs | 6 ++++- src/proto/io.rs | 45 ++++++++++++++++++++++++-------------- src/server/mod.rs | 12 ++++++++++ src/server/server_proto.rs | 3 +++ tests/server.rs | 35 +++++++++++++++++++++++++++++ 5 files changed, 84 insertions(+), 17 deletions(-) diff --git a/src/proto/conn.rs b/src/proto/conn.rs index c7999488c2..0ecbb5d8c1 100644 --- a/src/proto/conn.rs +++ b/src/proto/conn.rs @@ -58,6 +58,10 @@ where I: AsyncRead + AsyncWrite, self.io.set_flush_pipeline(enabled); } + pub fn set_max_buf_size(&mut self, max: usize) { + self.io.set_max_buf_size(max); + } + #[cfg(feature = "tokio-proto")] fn poll_incoming(&mut self) -> Poll, super::Chunk, ::Error>>, io::Error> { trace!("Conn::poll_incoming()"); @@ -1221,7 +1225,7 @@ mod tests { let _: Result<(), ()> = future::lazy(|| { let io = AsyncIo::new_buf(vec![], 0); let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io, Default::default()); - let max = ::proto::io::MAX_BUFFER_SIZE + 4096; + let max = ::proto::io::DEFAULT_MAX_BUFFER_SIZE + 4096; conn.state.writing = Writing::Body(Encoder::length((max * 2) as u64), None); assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'a'; 1024 * 8].into()) }).unwrap().is_ready()); diff --git a/src/proto/io.rs b/src/proto/io.rs index 4e63ba5e4b..06b86c91a9 100644 --- a/src/proto/io.rs +++ b/src/proto/io.rs @@ -10,11 +10,12 @@ use super::{Http1Transaction, MessageHead}; use bytes::{BytesMut, Bytes}; const INIT_BUFFER_SIZE: usize = 8192; -pub const MAX_BUFFER_SIZE: usize = 8192 + 4096 * 100; +pub const DEFAULT_MAX_BUFFER_SIZE: usize = 8192 + 4096 * 100; pub struct Buffered { flush_pipeline: bool, io: T, + max_buf_size: usize, read_blocked: bool, read_buf: BytesMut, write_buf: WriteBuf, @@ -34,6 +35,7 @@ impl Buffered { Buffered { flush_pipeline: false, io: io, + max_buf_size: DEFAULT_MAX_BUFFER_SIZE, read_buf: BytesMut::with_capacity(0), write_buf: WriteBuf::new(), read_blocked: false, @@ -44,6 +46,11 @@ impl Buffered { self.flush_pipeline = enabled; } + pub fn set_max_buf_size(&mut self, max: usize) { + self.max_buf_size = max; + self.write_buf.max_buf_size = max; + } + pub fn read_buf(&self) -> &[u8] { self.read_buf.as_ref() } @@ -51,7 +58,7 @@ impl Buffered { pub fn write_buf_mut(&mut self) -> &mut Vec { self.write_buf.maybe_reset(); self.write_buf.maybe_reserve(0); - &mut self.write_buf.0.bytes + &mut self.write_buf.buf.bytes } pub fn consume_leading_lines(&mut self) { @@ -75,8 +82,8 @@ impl Buffered { return Ok(Async::Ready(head)) }, None => { - if self.read_buf.capacity() >= MAX_BUFFER_SIZE { - debug!("MAX_BUFFER_SIZE reached, closing"); + if self.read_buf.capacity() >= self.max_buf_size { + debug!("max_buf_size ({}) reached, closing", self.max_buf_size); return Err(::Error::TooLarge); } }, @@ -259,22 +266,28 @@ impl AtomicWrite for T { // an internal buffer to collect writes before flushes #[derive(Debug)] -struct WriteBuf(Cursor>); +struct WriteBuf{ + buf: Cursor>, + max_buf_size: usize, +} impl WriteBuf { fn new() -> WriteBuf { - WriteBuf(Cursor::new(Vec::new())) + WriteBuf { + buf: Cursor::new(Vec::new()), + max_buf_size: DEFAULT_MAX_BUFFER_SIZE, + } } fn write_into(&mut self, w: &mut W) -> io::Result { - self.0.write_to(w) + self.buf.write_to(w) } fn buffer(&mut self, data: &[u8]) -> usize { trace!("WriteBuf::buffer() len = {:?}", data.len()); self.maybe_reset(); self.maybe_reserve(data.len()); - let vec = &mut self.0.bytes; + let vec = &mut self.buf.bytes; let len = cmp::min(vec.capacity() - vec.len(), data.len()); assert!(vec.capacity() - vec.len() >= len); unsafe { @@ -291,28 +304,28 @@ impl WriteBuf { } fn remaining(&self) -> usize { - self.0.remaining() + self.buf.remaining() } #[inline] fn maybe_reserve(&mut self, needed: usize) { - let vec = &mut self.0.bytes; + let vec = &mut self.buf.bytes; let cap = vec.capacity(); if cap == 0 { - let init = cmp::min(MAX_BUFFER_SIZE, cmp::max(INIT_BUFFER_SIZE, needed)); + let init = cmp::min(self.max_buf_size, cmp::max(INIT_BUFFER_SIZE, needed)); trace!("WriteBuf reserving initial {}", init); vec.reserve(init); - } else if cap < MAX_BUFFER_SIZE { - vec.reserve(cmp::min(needed, MAX_BUFFER_SIZE - cap)); + } else if cap < self.max_buf_size { + vec.reserve(cmp::min(needed, self.max_buf_size - cap)); trace!("WriteBuf reserved {}", vec.capacity() - cap); } } fn maybe_reset(&mut self) { - if self.0.pos != 0 && self.0.remaining() == 0 { - self.0.pos = 0; + if self.buf.pos != 0 && self.buf.remaining() == 0 { + self.buf.pos = 0; unsafe { - self.0.bytes.set_len(0); + self.buf.bytes.set_len(0); } } } diff --git a/src/server/mod.rs b/src/server/mod.rs index dae0d9b9a1..cbec7264ea 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -53,6 +53,7 @@ pub use self::service::{const_service, service_fn}; /// which handle a connection to an HTTP server. Each instance of `Http` can be /// configured with various protocol-level options such as keepalive. pub struct Http { + max_buf_size: Option, keep_alive: bool, pipeline: bool, _marker: PhantomData, @@ -129,6 +130,7 @@ impl + 'static> Http { pub fn new() -> Http { Http { keep_alive: true, + max_buf_size: None, pipeline: false, _marker: PhantomData, } @@ -142,6 +144,12 @@ impl + 'static> Http { self } + /// Set the maximum buffer size for the connection. + pub fn max_buf_size(&mut self, max: usize) -> &mut Self { + self.max_buf_size = Some(max); + self + } + /// Aggregates flushes to better support pipelined responses. /// /// Experimental, may be have bugs. @@ -226,6 +234,7 @@ impl + 'static> Http { new_service: new_service, protocol: Http { keep_alive: self.keep_alive, + max_buf_size: self.max_buf_size, pipeline: self.pipeline, _marker: PhantomData, }, @@ -250,6 +259,9 @@ impl + 'static> Http { }; let mut conn = proto::Conn::new(io, ka); conn.set_flush_pipeline(self.pipeline); + if let Some(max) = self.max_buf_size { + conn.set_max_buf_size(max); + } Connection { conn: proto::dispatch::Dispatcher::new(proto::dispatch::Server::new(service), conn), } diff --git a/src/server/server_proto.rs b/src/server/server_proto.rs index 5afa347ba0..28563e19eb 100644 --- a/src/server/server_proto.rs +++ b/src/server/server_proto.rs @@ -113,6 +113,9 @@ impl ServerProto for Http }; let mut conn = proto::Conn::new(io, ka); conn.set_flush_pipeline(self.pipeline); + if let Some(max) = self.max_buf_size { + conn.set_max_buf_size(max); + } __ProtoBindTransport { inner: future::ok(conn), } diff --git a/tests/server.rs b/tests/server.rs index d6a70f4590..8af92dd067 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -958,6 +958,41 @@ fn illegal_request_length_returns_400_response() { core.run(fut).unwrap_err(); } +#[test] +fn max_buf_size() { + let _ = pretty_env_logger::try_init(); + let mut core = Core::new().unwrap(); + let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap(); + let addr = listener.local_addr().unwrap(); + + const MAX: usize = 16_000; + + thread::spawn(move || { + let mut tcp = connect(&addr); + tcp.write_all(b"POST /").expect("write 1"); + tcp.write_all(&vec![b'a'; MAX]).expect("write 2"); + tcp.write_all(b" HTTP/1.1\r\n\r\n").expect("write 3"); + let mut buf = [0; 256]; + tcp.read(&mut buf).expect("read 1"); + + let expected = "HTTP/1.1 400 "; + assert_eq!(s(&buf[..expected.len()]), expected); + }); + + let fut = listener.incoming() + .into_future() + .map_err(|_| unreachable!()) + .and_then(|(item, _incoming)| { + let (socket, _) = item.unwrap(); + Http::::new() + .max_buf_size(MAX) + .serve_connection(socket, HelloWorld) + .map(|_| ()) + }); + + core.run(fut).unwrap_err(); +} + #[test] fn remote_addr() { let server = serve();