From be4e7181456844180963d0e5234656c319ce92a6 Mon Sep 17 00:00:00 2001 From: Marko Lalic Date: Sat, 5 Sep 2015 15:29:31 +0200 Subject: [PATCH] fix(http): Add a stream enum that makes it impossible to lose a stream This removes a number of possible panics... --- src/http/h1.rs | 456 +++++++++++++++++++++++++++++-------------------- 1 file changed, 267 insertions(+), 189 deletions(-) diff --git a/src/http/h1.rs b/src/http/h1.rs index 239f96708b..99f13a73d2 100644 --- a/src/http/h1.rs +++ b/src/http/h1.rs @@ -35,19 +35,70 @@ use version; const MAX_INVALID_RESPONSE_BYTES: usize = 1024 * 128; +#[derive(Debug)] +struct Wrapper { + obj: Option, +} + +impl Wrapper { + pub fn new(obj: T) -> Wrapper { + Wrapper { obj: Some(obj) } + } + + pub fn map_in_place(&mut self, f: F) where F: FnOnce(T) -> T { + let obj = self.obj.take().unwrap(); + let res = f(obj); + self.obj = Some(res); + } + + pub fn into_inner(self) -> T { self.obj.unwrap() } + pub fn as_mut(&mut self) -> &mut T { self.obj.as_mut().unwrap() } + pub fn as_ref(&self) -> &T { self.obj.as_ref().unwrap() } +} + +#[derive(Debug)] +enum Stream { + Idle(Box), + Writing(HttpWriter>>), + Reading(HttpReader>>), +} + +impl Stream { + fn writer_mut(&mut self) -> Option<&mut HttpWriter>>> { + match *self { + Stream::Writing(ref mut writer) => Some(writer), + _ => None, + } + } + fn reader_mut(&mut self) -> Option<&mut HttpReader>>> { + match *self { + Stream::Reading(ref mut reader) => Some(reader), + _ => None, + } + } + fn reader_ref(&self) -> Option<&HttpReader>>> { + match *self { + Stream::Reading(ref reader) => Some(reader), + _ => None, + } + } + + fn new(stream: Box) -> Stream { + Stream::Idle(stream) + } +} + /// An implementation of the `HttpMessage` trait for HTTP/1.1. #[derive(Debug)] pub struct Http11Message { method: Option, - stream: Option>, - writer: Option>>>, - reader: Option>>>, + stream: Wrapper, } impl Write for Http11Message { #[inline] fn write(&mut self, buf: &[u8]) -> io::Result { - match self.writer { + match self.stream.as_mut().writer_mut() { None => Err(io::Error::new(io::ErrorKind::Other, "Not in a writable state")), Some(ref mut writer) => writer.write(buf), @@ -55,7 +106,7 @@ impl Write for Http11Message { } #[inline] fn flush(&mut self) -> io::Result<()> { - match self.writer { + match self.stream.as_mut().writer_mut() { None => Err(io::Error::new(io::ErrorKind::Other, "Not in a writable state")), Some(ref mut writer) => writer.flush(), @@ -66,7 +117,7 @@ impl Write for Http11Message { impl Read for Http11Message { #[inline] fn read(&mut self, buf: &mut [u8]) -> io::Result { - match self.reader { + match self.stream.as_mut().reader_mut() { None => Err(io::Error::new(io::ErrorKind::Other, "Not in a readable state")), Some(ref mut reader) => reader.read(buf), @@ -76,178 +127,215 @@ impl Read for Http11Message { impl HttpMessage for Http11Message { fn set_outgoing(&mut self, mut head: RequestHead) -> ::Result { - let stream = match self.stream.take() { - Some(stream) => stream, - None => { - return Err(From::from(io::Error::new( + let mut res = Err(Error::from(io::Error::new( io::ErrorKind::Other, - "Message not idle, cannot start new outgoing"))); - } - }; - let mut stream = BufWriter::new(stream); - - let mut uri = head.url.serialize_path().unwrap(); - if let Some(ref q) = head.url.query { - uri.push('?'); - uri.push_str(&q[..]); - } + ""))); + let mut method = None; + self.stream.map_in_place(|stream: Stream| -> Stream { + let stream = match stream { + Stream::Idle(stream) => stream, + _ => { + res = Err(Error::from(io::Error::new( + io::ErrorKind::Other, + "Message not idle, cannot start new outgoing"))); + return stream; + }, + }; + let mut stream = BufWriter::new(stream); - let version = version::HttpVersion::Http11; - debug!("request line: {:?} {:?} {:?}", head.method, uri, version); - try!(write!(&mut stream, "{} {} {}{}", - head.method, uri, version, LINE_ENDING)); + let mut uri = head.url.serialize_path().unwrap(); + if let Some(ref q) = head.url.query { + uri.push('?'); + uri.push_str(&q[..]); + } - let stream = { - let mut write_headers = |mut stream: BufWriter>, head: &RequestHead| { - debug!("headers={:?}", head.headers); - match write!(&mut stream, "{}{}", head.headers, LINE_ENDING) { - Ok(_) => Ok(stream), - Err(e) => { - self.stream = Some(stream.into_inner().unwrap()); - Err(e) + let version = version::HttpVersion::Http11; + debug!("request line: {:?} {:?} {:?}", head.method, uri, version); + match write!(&mut stream, "{} {} {}{}", + head.method, uri, version, LINE_ENDING) { + Err(e) => { + res = Err(From::from(e)); + // TODO What should we do if the BufWriter doesn't wanna + // relinquish the stream? + return Stream::Idle(stream.into_inner().ok().unwrap()); + }, + Ok(_) => {}, + }; + + let stream = { + let write_headers = |mut stream: BufWriter>, head: &RequestHead| { + debug!("headers={:?}", head.headers); + match write!(&mut stream, "{}{}", head.headers, LINE_ENDING) { + Ok(_) => Ok(stream), + Err(e) => { + Err((e, stream.into_inner().unwrap())) + } } - } - }; - match &head.method { - &Method::Get | &Method::Head => { - EmptyWriter(try!(write_headers(stream, &head))) - }, - _ => { - let mut chunked = true; - let mut len = 0; - - match head.headers.get::() { - Some(cl) => { - chunked = false; - len = **cl; - }, - None => () - }; - - // can't do in match above, thanks borrowck - if chunked { - let encodings = match head.headers.get_mut::() { - Some(encodings) => { - //TODO: check if chunked is already in encodings. use HashSet? - encodings.push(header::Encoding::Chunked); - false + }; + match &head.method { + &Method::Get | &Method::Head => { + let writer = match write_headers(stream, &head) { + Ok(w) => w, + Err(e) => { + res = Err(From::from(e.0)); + return Stream::Idle(e.1); + } + }; + EmptyWriter(writer) + }, + _ => { + let mut chunked = true; + let mut len = 0; + + match head.headers.get::() { + Some(cl) => { + chunked = false; + len = **cl; }, - None => true + None => () }; - if encodings { - head.headers.set( - header::TransferEncoding(vec![header::Encoding::Chunked])) + // can't do in match above, thanks borrowck + if chunked { + let encodings = match head.headers.get_mut::() { + Some(encodings) => { + //TODO: check if chunked is already in encodings. use HashSet? + encodings.push(header::Encoding::Chunked); + false + }, + None => true + }; + + if encodings { + head.headers.set( + header::TransferEncoding(vec![header::Encoding::Chunked])) + } } - } - let stream = try!(write_headers(stream, &head)); + let stream = match write_headers(stream, &head) { + Ok(s) => s, + Err(e) => { + res = Err(From::from(e.0)); + return Stream::Idle(e.1); + }, + }; - if chunked { - ChunkedWriter(stream) - } else { - SizedWriter(stream, len) + if chunked { + ChunkedWriter(stream) + } else { + SizedWriter(stream, len) + } } } - } - }; + }; - self.writer = Some(stream); - self.method = Some(head.method.clone()); + method = Some(head.method.clone()); + res = Ok(head); + Stream::Writing(stream) + }); - Ok(head) + self.method = method; + res } fn get_incoming(&mut self) -> ::Result { try!(self.flush_outgoing()); - let stream = match self.stream.take() { - Some(stream) => stream, - None => { - // The message was already in the reading state... - // TODO Decide what happens in case we try to get a new incoming at that point - return Err(From::from( + let method = self.method.take().unwrap_or(Method::Get); + let mut res = Err(From::from( io::Error::new(io::ErrorKind::Other, "Read already in progress"))); - } - }; - - let expected_no_content = stream.previous_response_expected_no_content(); - trace!("previous_response_expected_no_content = {}", expected_no_content); - - let mut stream = BufReader::new(stream); - - let mut invalid_bytes_read = 0; - let head; - loop { - head = match parse_response(&mut stream) { - Ok(head) => head, - Err(::Error::Version) - if expected_no_content && invalid_bytes_read < MAX_INVALID_RESPONSE_BYTES => { - trace!("expected_no_content, found content"); - invalid_bytes_read += 1; - stream.consume(1); - continue; - } - Err(e) => { - self.stream = Some(stream.into_inner()); - return Err(e); + self.stream.map_in_place(|stream| { + let stream = match stream { + Stream::Idle(stream) => stream, + _ => { + // The message was already in the reading state... + // TODO Decide what happens in case we try to get a new incoming at that point + res = Err(From::from( + io::Error::new(io::ErrorKind::Other, + "Read already in progress"))); + return stream; } }; - break; - } - - let raw_status = head.subject; - let headers = head.headers; - let method = self.method.take().unwrap_or(Method::Get); + let expected_no_content = stream.previous_response_expected_no_content(); + trace!("previous_response_expected_no_content = {}", expected_no_content); + + let mut stream = BufReader::new(stream); + + let mut invalid_bytes_read = 0; + let head; + loop { + head = match parse_response(&mut stream) { + Ok(head) => head, + Err(::Error::Version) + if expected_no_content && invalid_bytes_read < MAX_INVALID_RESPONSE_BYTES => { + trace!("expected_no_content, found content"); + invalid_bytes_read += 1; + stream.consume(1); + continue; + } + Err(e) => { + res = Err(e); + return Stream::Idle(stream.into_inner()); + } + }; + break; + } - let is_empty = !should_have_response_body(&method, raw_status.0); - stream.get_mut().set_previous_response_expected_no_content(is_empty); - // According to https://tools.ietf.org/html/rfc7230#section-3.3.3 - // 1. HEAD reponses, and Status 1xx, 204, and 304 cannot have a body. - // 2. Status 2xx to a CONNECT cannot have a body. - // 3. Transfer-Encoding: chunked has a chunked body. - // 4. If multiple differing Content-Length headers or invalid, close connection. - // 5. Content-Length header has a sized body. - // 6. Not Client. - // 7. Read till EOF. - self.reader = Some(if is_empty { - EmptyReader(stream) - } else { - if let Some(&TransferEncoding(ref codings)) = headers.get() { - if codings.last() == Some(&Chunked) { - ChunkedReader(stream, None) + let raw_status = head.subject; + let headers = head.headers; + + let is_empty = !should_have_response_body(&method, raw_status.0); + stream.get_mut().set_previous_response_expected_no_content(is_empty); + // According to https://tools.ietf.org/html/rfc7230#section-3.3.3 + // 1. HEAD reponses, and Status 1xx, 204, and 304 cannot have a body. + // 2. Status 2xx to a CONNECT cannot have a body. + // 3. Transfer-Encoding: chunked has a chunked body. + // 4. If multiple differing Content-Length headers or invalid, close connection. + // 5. Content-Length header has a sized body. + // 6. Not Client. + // 7. Read till EOF. + let reader = if is_empty { + EmptyReader(stream) + } else { + if let Some(&TransferEncoding(ref codings)) = headers.get() { + if codings.last() == Some(&Chunked) { + ChunkedReader(stream, None) + } else { + trace!("not chuncked. read till eof"); + EofReader(stream) + } + } else if let Some(&ContentLength(len)) = headers.get() { + SizedReader(stream, len) + } else if headers.has::() { + trace!("illegal Content-Length: {:?}", headers.get_raw("Content-Length")); + res = Err(Error::Header); + return Stream::Idle(stream.into_inner()); } else { - trace!("not chuncked. read till eof"); + trace!("neither Transfer-Encoding nor Content-Length"); EofReader(stream) } - } else if let Some(&ContentLength(len)) = headers.get() { - SizedReader(stream, len) - } else if headers.has::() { - trace!("illegal Content-Length: {:?}", headers.get_raw("Content-Length")); - self.stream = Some(stream.into_inner()); - return Err(Error::Header); - } else { - trace!("neither Transfer-Encoding nor Content-Length"); - EofReader(stream) - } - }); + }; - trace!("Http11Message.reader = {:?}", self.reader); + trace!("Http11Message.reader = {:?}", reader); - Ok(ResponseHead { - headers: headers, - raw_status: raw_status, - version: head.version, - }) + res = Ok(ResponseHead { + headers: headers, + raw_status: raw_status, + version: head.version, + }); + + Stream::Reading(reader) + }); + res } fn has_body(&self) -> bool { - match self.reader { - Some(EmptyReader(..)) | - Some(SizedReader(_, 0)) | - Some(ChunkedReader(_, Some(0))) => false, + match self.stream.as_ref().reader_ref() { + Some(&EmptyReader(..)) | + Some(&SizedReader(_, 0)) | + Some(&ChunkedReader(_, Some(0))) => false, // specifically EofReader is always true _ => true } @@ -274,43 +362,31 @@ impl HttpMessage for Http11Message { impl Http11Message { /// Consumes the `Http11Message` and returns the underlying `NetworkStream`. - pub fn into_inner(mut self) -> Box { - if self.stream.is_some() { - self.stream.take().unwrap() - } else if self.writer.is_some() { - self.writer.take().unwrap().into_inner().into_inner().unwrap() - } else if self.reader.is_some() { - self.reader.take().unwrap().into_inner().into_inner() - } else { - panic!("Http11Message lost its underlying stream somehow"); + pub fn into_inner(self) -> Box { + match self.stream.into_inner() { + Stream::Idle(stream) => stream, + Stream::Writing(stream) => stream.into_inner().into_inner().unwrap(), + Stream::Reading(stream) => stream.into_inner().into_inner(), } } /// Gets a mutable reference to the underlying `NetworkStream`, regardless of the state of the /// `Http11Message`. pub fn get_ref(&self) -> &(NetworkStream + Send) { - if self.stream.is_some() { - &**self.stream.as_ref().unwrap() - } else if self.writer.is_some() { - &**self.writer.as_ref().unwrap().get_ref().get_ref() - } else if self.reader.is_some() { - &**self.reader.as_ref().unwrap().get_ref().get_ref() - } else { - panic!("Http11Message lost its underlying stream somehow"); + match *self.stream.as_ref() { + Stream::Idle(ref stream) => &**stream, + Stream::Writing(ref stream) => &**stream.get_ref().get_ref(), + Stream::Reading(ref stream) => &**stream.get_ref().get_ref() } } /// Gets a mutable reference to the underlying `NetworkStream`, regardless of the state of the /// `Http11Message`. pub fn get_mut(&mut self) -> &mut (NetworkStream + Send) { - if self.stream.is_some() { - &mut **self.stream.as_mut().unwrap() - } else if self.writer.is_some() { - &mut **self.writer.as_mut().unwrap().get_mut().get_mut() - } else if self.reader.is_some() { - &mut **self.reader.as_mut().unwrap().get_mut().get_mut() - } else { - panic!("Http11Message lost its underlying stream somehow"); + match *self.stream.as_mut() { + Stream::Idle(ref mut stream) => &mut **stream, + Stream::Writing(ref mut stream) => &mut **stream.get_mut().get_mut(), + Stream::Reading(ref mut stream) => &mut **stream.get_mut().get_mut() } } @@ -319,9 +395,7 @@ impl Http11Message { pub fn with_stream(stream: Box) -> Http11Message { Http11Message { method: None, - stream: Some(stream), - writer: None, - reader: None, + stream: Wrapper::new(Stream::new(stream)), } } @@ -329,22 +403,26 @@ impl Http11Message { /// /// TODO It might be sensible to lift this up to the `HttpMessage` trait itself... pub fn flush_outgoing(&mut self) -> ::Result<()> { - match self.writer { - None => return Ok(()), - Some(_) => {}, - }; - - let writer = self.writer.take().unwrap(); - // end() already flushes - let raw = match writer.end() { - Ok(buf) => buf.into_inner().unwrap(), - Err(e) => { - self.writer = Some(e.1); - return Err(From::from(e.0)); - } - }; - self.stream = Some(raw); - Ok(()) + let mut res = Ok(()); + self.stream.map_in_place(|stream| { + let writer = match stream { + Stream::Writing(writer) => writer, + _ => { + res = Ok(()); + return stream; + }, + }; + // end() already flushes + let raw = match writer.end() { + Ok(buf) => buf.into_inner().unwrap(), + Err(e) => { + res = Err(From::from(e.0)); + return Stream::Writing(e.1); + } + }; + Stream::Idle(raw) + }); + res } }