Skip to content

Commit

Permalink
fix(http1): only send 100 Continue if request body is polled
Browse files Browse the repository at this point in the history
Before, if a client request included an `Expect: 100-continue` header,
the `100 Continue` response was sent immediately. However, this is
problematic if the service is going to reply with some 4xx status code
and reject the body.

This change delays the automatic sending of the `100 Continue` status
until the service has call `poll_data` on the request body once.
  • Loading branch information
seanmonstar committed Jan 29, 2020
1 parent a354580 commit c4bb4db
Show file tree
Hide file tree
Showing 7 changed files with 332 additions and 39 deletions.
173 changes: 151 additions & 22 deletions src/body/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use futures_util::TryStreamExt;
use http::HeaderMap;
use http_body::{Body as HttpBody, SizeHint};

use crate::common::{task, Future, Never, Pin, Poll};
use crate::common::{task, watch, Future, Never, Pin, Poll};
use crate::proto::DecodedLength;
use crate::upgrade::OnUpgrade;

Expand All @@ -33,7 +33,7 @@ enum Kind {
Once(Option<Bytes>),
Chan {
content_length: DecodedLength,
abort_rx: oneshot::Receiver<()>,
want_tx: watch::Sender,
rx: mpsc::Receiver<Result<Bytes, crate::Error>>,
},
H2 {
Expand Down Expand Up @@ -79,12 +79,14 @@ enum DelayEof {
/// Useful when wanting to stream chunks from another thread. See
/// [`Body::channel`](Body::channel) for more.
#[must_use = "Sender does nothing unless sent on"]
#[derive(Debug)]
pub struct Sender {
abort_tx: oneshot::Sender<()>,
want_rx: watch::Receiver,
tx: BodySender,
}

const WANT_PENDING: usize = 1;
const WANT_READY: usize = 2;

impl Body {
/// Create an empty `Body` stream.
///
Expand All @@ -106,17 +108,22 @@ impl Body {
/// Useful when wanting to stream chunks from another thread.
#[inline]
pub fn channel() -> (Sender, Body) {
Self::new_channel(DecodedLength::CHUNKED)
Self::new_channel(DecodedLength::CHUNKED, /*wanter =*/ false)
}

pub(crate) fn new_channel(content_length: DecodedLength) -> (Sender, Body) {
pub(crate) fn new_channel(content_length: DecodedLength, wanter: bool) -> (Sender, Body) {
let (tx, rx) = mpsc::channel(0);
let (abort_tx, abort_rx) = oneshot::channel();

let tx = Sender { abort_tx, tx };
// If wanter is true, `Sender::poll_ready()` won't becoming ready
// until the `Body` has been polled for data once.
let want = if wanter { WANT_PENDING } else { WANT_READY };

let (want_tx, want_rx) = watch::channel(want);

let tx = Sender { want_rx, tx };
let rx = Body::new(Kind::Chan {
content_length,
abort_rx,
want_tx,
rx,
});

Expand Down Expand Up @@ -236,11 +243,9 @@ impl Body {
Kind::Chan {
content_length: ref mut len,
ref mut rx,
ref mut abort_rx,
ref mut want_tx,
} => {
if let Poll::Ready(Ok(())) = Pin::new(abort_rx).poll(cx) {
return Poll::Ready(Some(Err(crate::Error::new_body_write_aborted())));
}
want_tx.send(WANT_READY);

match ready!(Pin::new(rx).poll_next(cx)?) {
Some(chunk) => {
Expand Down Expand Up @@ -460,19 +465,29 @@ impl From<Cow<'static, str>> for Body {
impl Sender {
/// Check to see if this `Sender` can send more data.
pub fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
match self.abort_tx.poll_canceled(cx) {
Poll::Ready(()) => return Poll::Ready(Err(crate::Error::new_closed())),
Poll::Pending => (), // fallthrough
}

// Check if the receiver end has tried polling for the body yet
ready!(self.poll_want(cx)?);
self.tx
.poll_ready(cx)
.map_err(|_| crate::Error::new_closed())
}

fn poll_want(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
match self.want_rx.load(cx) {
WANT_READY => Poll::Ready(Ok(())),
WANT_PENDING => Poll::Pending,
watch::CLOSED => Poll::Ready(Err(crate::Error::new_closed())),
unexpected => unreachable!("want_rx value: {}", unexpected),
}
}

async fn ready(&mut self) -> crate::Result<()> {
futures_util::future::poll_fn(|cx| self.poll_ready(cx)).await
}

/// Send data on this channel when it is ready.
pub async fn send_data(&mut self, chunk: Bytes) -> crate::Result<()> {
futures_util::future::poll_fn(|cx| self.poll_ready(cx)).await?;
self.ready().await?;
self.tx
.try_send(Ok(chunk))
.map_err(|_| crate::Error::new_closed())
Expand All @@ -498,20 +513,41 @@ impl Sender {

/// Aborts the body in an abnormal fashion.
pub fn abort(self) {
// TODO(sean): this can just be `self.tx.clone().try_send()`
let _ = self.abort_tx.send(());
let _ = self
.tx
// clone so the send works even if buffer is full
.clone()
.try_send(Err(crate::Error::new_body_write_aborted()));
}

pub(crate) fn send_error(&mut self, err: crate::Error) {
let _ = self.tx.try_send(Err(err));
}
}

impl fmt::Debug for Sender {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
#[derive(Debug)]
struct Open;
#[derive(Debug)]
struct Closed;

let mut builder = f.debug_tuple("Sender");
match self.want_rx.peek() {
watch::CLOSED => builder.field(&Closed),
_ => builder.field(&Open),
};

builder.finish()
}
}

#[cfg(test)]
mod tests {
use std::mem;
use std::task::Poll;

use super::{Body, Sender};
use super::{Body, DecodedLength, HttpBody, Sender};

#[test]
fn test_size_of() {
Expand Down Expand Up @@ -541,4 +577,97 @@ mod tests {
"Option<Sender>"
);
}

#[tokio::test]
async fn channel_abort() {
let (tx, mut rx) = Body::channel();

tx.abort();

let err = rx.data().await.unwrap().unwrap_err();
assert!(err.is_body_write_aborted(), "{:?}", err);
}

#[tokio::test]
async fn channel_abort_when_buffer_is_full() {
let (mut tx, mut rx) = Body::channel();

tx.try_send_data("chunk 1".into()).expect("send 1");
// buffer is full, but can still send abort
tx.abort();

let chunk1 = rx.data().await.expect("item 1").expect("chunk 1");
assert_eq!(chunk1, "chunk 1");

let err = rx.data().await.unwrap().unwrap_err();
assert!(err.is_body_write_aborted(), "{:?}", err);
}

#[test]
fn channel_buffers_one() {
let (mut tx, _rx) = Body::channel();

tx.try_send_data("chunk 1".into()).expect("send 1");

// buffer is now full
let chunk2 = tx.try_send_data("chunk 2".into()).expect_err("send 2");
assert_eq!(chunk2, "chunk 2");
}

#[tokio::test]
async fn channel_empty() {
let (_, mut rx) = Body::channel();

assert!(rx.data().await.is_none());
}

#[test]
fn channel_ready() {
let (mut tx, _rx) = Body::new_channel(DecodedLength::CHUNKED, /*wanter = */ false);

let mut tx_ready = tokio_test::task::spawn(tx.ready());

assert!(tx_ready.poll().is_ready(), "tx is ready immediately");
}

#[test]
fn channel_wanter() {
let (mut tx, mut rx) = Body::new_channel(DecodedLength::CHUNKED, /*wanter = */ true);

let mut tx_ready = tokio_test::task::spawn(tx.ready());
let mut rx_data = tokio_test::task::spawn(rx.data());

assert!(
tx_ready.poll().is_pending(),
"tx isn't ready before rx has been polled"
);

assert!(rx_data.poll().is_pending(), "poll rx.data");
assert!(tx_ready.is_woken(), "rx poll wakes tx");

assert!(
tx_ready.poll().is_ready(),
"tx is ready after rx has been polled"
);
}

#[test]
fn channel_notices_closure() {
let (mut tx, rx) = Body::new_channel(DecodedLength::CHUNKED, /*wanter = */ true);

let mut tx_ready = tokio_test::task::spawn(tx.ready());

assert!(
tx_ready.poll().is_pending(),
"tx isn't ready before rx has been polled"
);

drop(rx);
assert!(tx_ready.is_woken(), "dropping rx wakes tx");

match tx_ready.poll() {
Poll::Ready(Err(ref e)) if e.is_closed() => (),
unexpected => panic!("tx poll ready unexpected: {:?}", unexpected),
}
}
}
1 change: 1 addition & 0 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub(crate) mod io;
mod lazy;
mod never;
pub(crate) mod task;
pub(crate) mod watch;

pub use self::exec::Executor;
pub(crate) use self::exec::{BoxSendFuture, Exec};
Expand Down
73 changes: 73 additions & 0 deletions src/common/watch.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
//! An SPSC broadcast channel.
//!
//! - The value can only be a `usize`.
//! - The consumer is only notified if the value is different.
//! - The value `0` is reserved for closed.
use futures_util::task::AtomicWaker;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use std::task;

type Value = usize;

pub(crate) const CLOSED: usize = 0;

pub(crate) fn channel(initial: Value) -> (Sender, Receiver) {
debug_assert!(
initial != CLOSED,
"watch::channel initial state of 0 is reserved"
);

let shared = Arc::new(Shared {
value: AtomicUsize::new(initial),
waker: AtomicWaker::new(),
});

(
Sender {
shared: shared.clone(),
},
Receiver { shared },
)
}

pub(crate) struct Sender {
shared: Arc<Shared>,
}

pub(crate) struct Receiver {
shared: Arc<Shared>,
}

struct Shared {
value: AtomicUsize,
waker: AtomicWaker,
}

impl Sender {
pub(crate) fn send(&mut self, value: Value) {
if self.shared.value.swap(value, Ordering::SeqCst) != value {
self.shared.waker.wake();
}
}
}

impl Drop for Sender {
fn drop(&mut self) {
self.send(CLOSED);
}
}

impl Receiver {
pub(crate) fn load(&mut self, cx: &mut task::Context<'_>) -> Value {
self.shared.waker.register(cx.waker());
self.shared.value.load(Ordering::SeqCst)
}

pub(crate) fn peek(&self) -> Value {
self.shared.value.load(Ordering::Relaxed)
}
}
Loading

0 comments on commit c4bb4db

Please sign in to comment.