From 227742221fa7830a14c18becbbc6137d97b57729 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Tue, 23 Jan 2018 11:33:24 -0800 Subject: [PATCH] fix(client): error on unsupport 101 responses, ignore other 1xx codes --- src/error.rs | 18 ++++--- src/proto/conn.rs | 114 +++++++++++++++++++++------------------- src/proto/h1/parse.rs | 119 +++++++++++++++++++++++++++++------------- src/proto/mod.rs | 2 +- tests/client.rs | 66 +++++++++++++++++++++++ 5 files changed, 222 insertions(+), 97 deletions(-) diff --git a/src/error.rs b/src/error.rs index 32f6291fe3..5a2d448869 100644 --- a/src/error.rs +++ b/src/error.rs @@ -16,6 +16,7 @@ use self::Error::{ Header, Status, Timeout, + Upgrade, Io, TooLarge, Incomplete, @@ -44,6 +45,8 @@ pub enum Error { Status, /// A timeout occurred waiting for an IO event. Timeout, + /// A protocol upgrade was encountered, but not yet supported in hyper. + Upgrade, /// An `io::Error` that occurred while trying to read or write to a network stream. Io(IoError), /// Parsing a field as string failed @@ -76,13 +79,14 @@ impl fmt::Display for Error { impl StdError for Error { fn description(&self) -> &str { match *self { - Method => "Invalid Method specified", - Version => "Invalid HTTP version specified", - Header => "Invalid Header provided", - TooLarge => "Message head is too large", - Status => "Invalid Status provided", - Incomplete => "Message is incomplete", - Timeout => "Timeout", + Method => "invalid Method specified", + Version => "invalid HTTP version specified", + Header => "invalid Header provided", + TooLarge => "message head is too large", + Status => "invalid Status provided", + Incomplete => "message is incomplete", + Timeout => "timeout", + Upgrade => "unsupported protocol upgrade", Uri(ref e) => e.description(), Io(ref e) => e.description(), Utf8(ref e) => e.description(), diff --git a/src/proto/conn.rs b/src/proto/conn.rs index bda987f400..5f572db327 100644 --- a/src/proto/conn.rs +++ b/src/proto/conn.rs @@ -171,65 +171,71 @@ where I: AsyncRead + AsyncWrite, debug_assert!(self.can_read_head()); trace!("Conn::read_head"); - let (version, head) = match self.io.parse::() { - Ok(Async::Ready(head)) => (head.version, head), - Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(e) => { - // If we are currently waiting on a message, then an empty - // message should be reported as an error. If not, it is just - // the connection closing gracefully. - let must_error = self.should_error_on_eof(); - self.state.close_read(); - self.io.consume_leading_lines(); - let was_mid_parse = !self.io.read_buf().is_empty(); - return if was_mid_parse || must_error { - debug!("parse error ({}) with {} bytes", e, self.io.read_buf().len()); - Err(e) - } else { - debug!("read eof"); - Ok(Async::Ready(None)) - }; - } - }; + loop { + let (version, head) = match self.io.parse::() { + Ok(Async::Ready(head)) => (head.version, head), + Ok(Async::NotReady) => return Ok(Async::NotReady), + Err(e) => { + // If we are currently waiting on a message, then an empty + // message should be reported as an error. If not, it is just + // the connection closing gracefully. + let must_error = self.should_error_on_eof(); + self.state.close_read(); + self.io.consume_leading_lines(); + let was_mid_parse = !self.io.read_buf().is_empty(); + return if was_mid_parse || must_error { + debug!("parse error ({}) with {} bytes", e, self.io.read_buf().len()); + Err(e) + } else { + debug!("read eof"); + Ok(Async::Ready(None)) + }; + } + }; - self.state.version = match version { - HttpVersion::Http10 => Version::Http10, - HttpVersion::Http11 => Version::Http11, - _ => { - error!("unimplemented HTTP Version = {:?}", version); - self.state.close_read(); - return Err(::Error::Version); - } - }; + self.state.version = match version { + HttpVersion::Http10 => Version::Http10, + HttpVersion::Http11 => Version::Http11, + _ => { + error!("unimplemented HTTP Version = {:?}", version); + self.state.close_read(); + return Err(::Error::Version); + } + }; - let decoder = match T::decoder(&head, &mut self.state.method) { - Ok(d) => d, - Err(e) => { - debug!("decoder error = {:?}", e); - self.state.close_read(); - return Err(e); - } - }; + let decoder = match T::decoder(&head, &mut self.state.method) { + Ok(Some(d)) => d, + Ok(None) => { + // likely a 1xx message that we can ignore + continue; + } + Err(e) => { + debug!("decoder error = {:?}", e); + self.state.close_read(); + return Err(e); + } + }; - debug!("incoming body is {}", decoder); + debug!("incoming body is {}", decoder); - self.state.busy(); - if head.expecting_continue() { - let msg = b"HTTP/1.1 100 Continue\r\n\r\n"; - self.state.writing = Writing::Continue(Cursor::new(msg)); - } - let wants_keep_alive = head.should_keep_alive(); - self.state.keep_alive &= wants_keep_alive; - let (body, reading) = if decoder.is_eof() { - (false, Reading::KeepAlive) - } else { - (true, Reading::Body(decoder)) - }; - self.state.reading = reading; - if !body { - self.try_keep_alive(); + self.state.busy(); + if head.expecting_continue() { + let msg = b"HTTP/1.1 100 Continue\r\n\r\n"; + self.state.writing = Writing::Continue(Cursor::new(msg)); + } + let wants_keep_alive = head.should_keep_alive(); + self.state.keep_alive &= wants_keep_alive; + let (body, reading) = if decoder.is_eof() { + (false, Reading::KeepAlive) + } else { + (true, Reading::Body(decoder)) + }; + self.state.reading = reading; + if !body { + self.try_keep_alive(); + } + return Ok(Async::Ready(Some((head, body)))); } - Ok(Async::Ready(Some((head, body)))) } pub fn read_body(&mut self) -> Poll, io::Error> { diff --git a/src/proto/h1/parse.rs b/src/proto/h1/parse.rs index 0050394405..3c6363b623 100644 --- a/src/proto/h1/parse.rs +++ b/src/proto/h1/parse.rs @@ -72,7 +72,7 @@ impl Http1Transaction for ServerTransaction { }, len))) } - fn decoder(head: &MessageHead, method: &mut Option) -> ::Result { + fn decoder(head: &MessageHead, method: &mut Option) -> ::Result> { use ::header; *method = Some(head.subject.0.clone()); @@ -91,19 +91,22 @@ impl Http1Transaction for ServerTransaction { // If Transfer-Encoding header is present, and 'chunked' is // not the final encoding, and this is a Request, then it is // mal-formed. A server should respond with 400 Bad Request. - if encodings.last() == Some(&header::Encoding::Chunked) { - Ok(Decoder::chunked()) + if head.version == Http10 { + debug!("HTTP/1.0 has Transfer-Encoding header"); + Err(::Error::Header) + } else if encodings.last() == Some(&header::Encoding::Chunked) { + Ok(Some(Decoder::chunked())) } else { debug!("request with transfer-encoding header, but not chunked, bad request"); Err(::Error::Header) } } else if let Some(&header::ContentLength(len)) = head.headers.get() { - Ok(Decoder::length(len)) + Ok(Some(Decoder::length(len))) } else if head.headers.has::() { debug!("illegal Content-Length: {:?}", head.headers.get_raw("Content-Length")); Err(::Error::Header) } else { - Ok(Decoder::length(0)) + Ok(Some(Decoder::length(0))) } } @@ -225,7 +228,7 @@ impl Http1Transaction for ClientTransaction { }, len))) } - fn decoder(inc: &MessageHead, method: &mut Option) -> ::Result { + fn decoder(inc: &MessageHead, method: &mut Option) -> ::Result> { // According to https://tools.ietf.org/html/rfc7230#section-3.3.3 // 1. HEAD responses, and Status 1xx, 204, and 304 cannot have a body. // 2. Status 2xx to a CONNECT cannot have a body. @@ -235,13 +238,26 @@ impl Http1Transaction for ClientTransaction { // 6. (irrelevant to Response) // 7. Read till EOF. + match inc.subject.0 { + 101 => { + debug!("received 101 upgrade response, not supported"); + return Err(::Error::Upgrade); + }, + 100...199 => { + trace!("ignoring informational response: {}", inc.subject.0); + return Ok(None); + }, + 204 | + 304 => return Ok(Some(Decoder::length(0))), + _ => (), + } match *method { Some(Method::Head) => { - return Ok(Decoder::length(0)); + return Ok(Some(Decoder::length(0))); } Some(Method::Connect) => match inc.subject.0 { 200...299 => { - return Ok(Decoder::length(0)); + return Ok(Some(Decoder::length(0))); }, _ => {}, }, @@ -251,28 +267,25 @@ impl Http1Transaction for ClientTransaction { } } - match inc.subject.0 { - 100...199 | - 204 | - 304 => return Ok(Decoder::length(0)), - _ => (), - } if let Some(&header::TransferEncoding(ref codings)) = inc.headers.get() { - if codings.last() == Some(&header::Encoding::Chunked) { - Ok(Decoder::chunked()) + if inc.version == Http10 { + debug!("HTTP/1.0 has Transfer-Encoding header"); + Err(::Error::Header) + } else if codings.last() == Some(&header::Encoding::Chunked) { + Ok(Some(Decoder::chunked())) } else { trace!("not chunked. read till eof"); - Ok(Decoder::eof()) + Ok(Some(Decoder::eof())) } } else if let Some(&header::ContentLength(len)) = inc.headers.get() { - Ok(Decoder::length(len)) + Ok(Some(Decoder::length(len))) } else if inc.headers.has::() { debug!("illegal Content-Length: {:?}", inc.headers.get_raw("Content-Length")); Err(::Error::Header) } else { trace!("neither Transfer-Encoding nor Content-Length"); - Ok(Decoder::eof()) + Ok(Some(Decoder::eof())) } } @@ -460,24 +473,24 @@ mod tests { let mut head = MessageHead::<::proto::RequestLine>::default(); head.subject.0 = ::Method::Get; - assert_eq!(Decoder::length(0), ServerTransaction::decoder(&head, method).unwrap()); + assert_eq!(Decoder::length(0), ServerTransaction::decoder(&head, method).unwrap().unwrap()); assert_eq!(*method, Some(::Method::Get)); head.subject.0 = ::Method::Post; - assert_eq!(Decoder::length(0), ServerTransaction::decoder(&head, method).unwrap()); + assert_eq!(Decoder::length(0), ServerTransaction::decoder(&head, method).unwrap().unwrap()); assert_eq!(*method, Some(::Method::Post)); head.headers.set(TransferEncoding::chunked()); - assert_eq!(Decoder::chunked(), ServerTransaction::decoder(&head, method).unwrap()); + assert_eq!(Decoder::chunked(), ServerTransaction::decoder(&head, method).unwrap().unwrap()); // transfer-encoding and content-length = chunked head.headers.set(ContentLength(10)); - assert_eq!(Decoder::chunked(), ServerTransaction::decoder(&head, method).unwrap()); + assert_eq!(Decoder::chunked(), ServerTransaction::decoder(&head, method).unwrap().unwrap()); head.headers.remove::(); - assert_eq!(Decoder::length(10), ServerTransaction::decoder(&head, method).unwrap()); + assert_eq!(Decoder::length(10), ServerTransaction::decoder(&head, method).unwrap().unwrap()); head.headers.set_raw("Content-Length", vec![b"5".to_vec(), b"5".to_vec()]); - assert_eq!(Decoder::length(5), ServerTransaction::decoder(&head, method).unwrap()); + assert_eq!(Decoder::length(5), ServerTransaction::decoder(&head, method).unwrap().unwrap()); head.headers.set_raw("Content-Length", vec![b"10".to_vec(), b"11".to_vec()]); ServerTransaction::decoder(&head, method).unwrap_err(); @@ -486,6 +499,21 @@ mod tests { head.headers.set_raw("Transfer-Encoding", "gzip"); ServerTransaction::decoder(&head, method).unwrap_err(); + + + // http/1.0 + head.version = ::HttpVersion::Http10; + head.headers.clear(); + + // 1.0 requests can only have bodies if content-length is set + assert_eq!(Decoder::length(0), ServerTransaction::decoder(&head, method).unwrap().unwrap()); + + head.headers.set(TransferEncoding::chunked()); + ServerTransaction::decoder(&head, method).unwrap_err(); + head.headers.remove::(); + + head.headers.set(ContentLength(15)); + assert_eq!(Decoder::length(15), ServerTransaction::decoder(&head, method).unwrap().unwrap()); } #[test] @@ -496,43 +524,64 @@ mod tests { let mut head = MessageHead::<::proto::RawStatus>::default(); head.subject.0 = 204; - assert_eq!(Decoder::length(0), ClientTransaction::decoder(&head, method).unwrap()); + assert_eq!(Decoder::length(0), ClientTransaction::decoder(&head, method).unwrap().unwrap()); head.subject.0 = 304; - assert_eq!(Decoder::length(0), ClientTransaction::decoder(&head, method).unwrap()); + assert_eq!(Decoder::length(0), ClientTransaction::decoder(&head, method).unwrap().unwrap()); head.subject.0 = 200; - assert_eq!(Decoder::eof(), ClientTransaction::decoder(&head, method).unwrap()); + assert_eq!(Decoder::eof(), ClientTransaction::decoder(&head, method).unwrap().unwrap()); *method = Some(::Method::Head); - assert_eq!(Decoder::length(0), ClientTransaction::decoder(&head, method).unwrap()); + assert_eq!(Decoder::length(0), ClientTransaction::decoder(&head, method).unwrap().unwrap()); *method = Some(::Method::Connect); - assert_eq!(Decoder::length(0), ClientTransaction::decoder(&head, method).unwrap()); + assert_eq!(Decoder::length(0), ClientTransaction::decoder(&head, method).unwrap().unwrap()); // CONNECT receiving non 200 can have a body head.subject.0 = 404; head.headers.set(ContentLength(10)); - assert_eq!(Decoder::length(10), ClientTransaction::decoder(&head, method).unwrap()); + assert_eq!(Decoder::length(10), ClientTransaction::decoder(&head, method).unwrap().unwrap()); head.headers.remove::(); *method = Some(::Method::Get); head.headers.set(TransferEncoding::chunked()); - assert_eq!(Decoder::chunked(), ClientTransaction::decoder(&head, method).unwrap()); + assert_eq!(Decoder::chunked(), ClientTransaction::decoder(&head, method).unwrap().unwrap()); // transfer-encoding and content-length = chunked head.headers.set(ContentLength(10)); - assert_eq!(Decoder::chunked(), ClientTransaction::decoder(&head, method).unwrap()); + assert_eq!(Decoder::chunked(), ClientTransaction::decoder(&head, method).unwrap().unwrap()); head.headers.remove::(); - assert_eq!(Decoder::length(10), ClientTransaction::decoder(&head, method).unwrap()); + assert_eq!(Decoder::length(10), ClientTransaction::decoder(&head, method).unwrap().unwrap()); head.headers.set_raw("Content-Length", vec![b"5".to_vec(), b"5".to_vec()]); - assert_eq!(Decoder::length(5), ClientTransaction::decoder(&head, method).unwrap()); + assert_eq!(Decoder::length(5), ClientTransaction::decoder(&head, method).unwrap().unwrap()); head.headers.set_raw("Content-Length", vec![b"10".to_vec(), b"11".to_vec()]); ClientTransaction::decoder(&head, method).unwrap_err(); + head.headers.clear(); + + // 1xx status codes + head.subject.0 = 100; + assert!(ClientTransaction::decoder(&head, method).unwrap().is_none()); + + head.subject.0 = 103; + assert!(ClientTransaction::decoder(&head, method).unwrap().is_none()); + + // 101 upgrade not supported yet + head.subject.0 = 101; + ClientTransaction::decoder(&head, method).unwrap_err(); + head.subject.0 = 200; + + // http/1.0 + head.version = ::HttpVersion::Http10; + + assert_eq!(Decoder::eof(), ClientTransaction::decoder(&head, method).unwrap().unwrap()); + + head.headers.set(TransferEncoding::chunked()); + ClientTransaction::decoder(&head, method).unwrap_err(); } #[cfg(feature = "nightly")] diff --git a/src/proto/mod.rs b/src/proto/mod.rs index b165492855..5bd8ead5b3 100644 --- a/src/proto/mod.rs +++ b/src/proto/mod.rs @@ -147,7 +147,7 @@ pub trait Http1Transaction { type Incoming; type Outgoing: Default; fn parse(bytes: &mut BytesMut) -> ParseResult; - fn decoder(head: &MessageHead, method: &mut Option<::Method>) -> ::Result; + fn decoder(head: &MessageHead, method: &mut Option<::Method>) -> ::Result>; fn encode(head: MessageHead, has_body: bool, method: &mut Option, dst: &mut Vec) -> h1::Encoder; fn should_error_on_parse_eof() -> bool; diff --git a/tests/client.rs b/tests/client.rs index d77cba32f5..c1f4f2e9d3 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -458,6 +458,72 @@ test! { } +test! { + name: client_100_continue, + + server: + expected: "\ + POST /continue HTTP/1.1\r\n\ + Host: {addr}\r\n\ + Content-Length: 7\r\n\ + \r\n\ + foo bar\ + ", + reply: "\ + HTTP/1.1 100 Continue\r\n\ + \r\n\ + HTTP/1.1 200 OK\r\n\ + Content-Length: 0\r\n\ + \r\n\ + ", + + client: + request: + method: Post, + url: "http://{addr}/continue", + headers: [ + ContentLength(7), + ], + body: Some("foo bar"), + proxy: false, + response: + status: Ok, + headers: [], + body: None, +} + + +test! { + name: client_101_upgrade, + + server: + expected: "\ + GET /upgrade HTTP/1.1\r\n\ + Host: {addr}\r\n\ + \r\n\ + ", + reply: "\ + HTTP/1.1 101 Switching Protocols\r\n\ + Upgrade: websocket\r\n\ + Connection: upgrade\r\n\ + \r\n\ + ", + + client: + request: + method: Get, + url: "http://{addr}/upgrade", + headers: [], + body: None, + proxy: false, + error: |err| match err { + &hyper::Error::Upgrade => true, + _ => false, + }, + +} + + #[test] fn client_keep_alive() { let server = TcpListener::bind("127.0.0.1:0").unwrap();