Skip to content

Commit

Permalink
network: Detect early that NotificationOutSubstream was closed by t…
Browse files Browse the repository at this point in the history
…he remote (paritytech#13396)
  • Loading branch information
dmitry-markin authored and ark0f committed Feb 27, 2023
1 parent b0719e6 commit d933077
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 12 deletions.
7 changes: 5 additions & 2 deletions client/network/src/protocol/notifications/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,9 @@ impl ConnectionHandler for NotifsHandler {
// performed before the code paths that can produce `Ready` (with some rare exceptions).
// Importantly, however, the flush is performed *after* notifications are queued with
// `Sink::start_send`.
// Note that we must call `poll_flush` on all substreams and not only on those we
// have called `Sink::start_send` on, because `NotificationsOutSubstream::poll_flush`
// also reports the substream termination (even if no data was written into it).
for protocol_index in 0..self.protocols.len() {
match &mut self.protocols[protocol_index].state {
State::Open { out_substream: out_substream @ Some(_), .. } => {
Expand Down Expand Up @@ -824,7 +827,7 @@ impl ConnectionHandler for NotifsHandler {
State::OpenDesiredByRemote { in_substream, pending_opening } =>
match NotificationsInSubstream::poll_process(Pin::new(in_substream), cx) {
Poll::Pending => {},
Poll::Ready(Ok(void)) => match void {},
Poll::Ready(Ok(())) => {},
Poll::Ready(Err(_)) => {
self.protocols[protocol_index].state =
State::Closed { pending_opening: *pending_opening };
Expand All @@ -840,7 +843,7 @@ impl ConnectionHandler for NotifsHandler {
cx,
) {
Poll::Pending => {},
Poll::Ready(Ok(void)) => match void {},
Poll::Ready(Ok(())) => {},
Poll::Ready(Err(_)) => *in_substream = None,
},
}
Expand Down
132 changes: 122 additions & 10 deletions client/network/src/protocol/notifications/upgrade/notifications.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ use libp2p::core::{upgrade, InboundUpgrade, OutboundUpgrade, UpgradeInfo};
use log::{error, warn};
use sc_network_common::protocol::ProtocolName;
use std::{
convert::Infallible,
io, mem,
pin::Pin,
task::{Context, Poll},
Expand Down Expand Up @@ -221,10 +220,7 @@ where

/// Equivalent to `Stream::poll_next`, except that it only drives the handshake and is
/// guaranteed to not generate any notification.
pub fn poll_process(
self: Pin<&mut Self>,
cx: &mut Context,
) -> Poll<Result<Infallible, io::Error>> {
pub fn poll_process(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
let mut this = self.project();

loop {
Expand All @@ -246,8 +242,10 @@ where
},
NotificationsInSubstreamHandshake::Flush => {
match Sink::poll_flush(this.socket.as_mut(), cx)? {
Poll::Ready(()) =>
*this.handshake = NotificationsInSubstreamHandshake::Sent,
Poll::Ready(()) => {
*this.handshake = NotificationsInSubstreamHandshake::Sent;
return Poll::Ready(Ok(()))
},
Poll::Pending => {
*this.handshake = NotificationsInSubstreamHandshake::Flush;
return Poll::Pending
Expand All @@ -260,7 +258,7 @@ where
st @ NotificationsInSubstreamHandshake::ClosingInResponseToRemote |
st @ NotificationsInSubstreamHandshake::BothSidesClosed => {
*this.handshake = st;
return Poll::Pending
return Poll::Ready(Ok(()))
},
}
}
Expand Down Expand Up @@ -443,6 +441,21 @@ where

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
let mut this = self.project();

// `Sink::poll_flush` does not expose stream closed error until we write something into
// the stream, so the code below makes sure we detect that the substream was closed
// even if we don't write anything into it.
match Stream::poll_next(this.socket.as_mut(), cx) {
Poll::Pending => {},
Poll::Ready(Some(_)) => {
error!(
target: "sub-libp2p",
"Unexpected incoming data in `NotificationsOutSubstream`",
);
},
Poll::Ready(None) => return Poll::Ready(Err(NotificationsOutError::Terminated)),
}

Sink::poll_flush(this.socket.as_mut(), cx).map_err(NotificationsOutError::Io)
}

Expand Down Expand Up @@ -492,13 +505,21 @@ pub enum NotificationsOutError {
/// I/O error on the substream.
#[error(transparent)]
Io(#[from] io::Error),

/// End of incoming data detected on out substream.
#[error("substream was closed/reset")]
Terminated,
}

#[cfg(test)]
mod tests {
use super::{NotificationsIn, NotificationsInOpen, NotificationsOut, NotificationsOutOpen};
use futures::{channel::oneshot, prelude::*};
use super::{
NotificationsIn, NotificationsInOpen, NotificationsOut, NotificationsOutError,
NotificationsOutOpen,
};
use futures::{channel::oneshot, future, prelude::*};
use libp2p::core::upgrade;
use std::{pin::Pin, task::Poll};
use tokio::net::{TcpListener, TcpStream};
use tokio_util::compat::TokioAsyncReadCompatExt;

Expand Down Expand Up @@ -691,4 +712,95 @@ mod tests {

client.await.unwrap();
}

#[tokio::test]
async fn send_handshake_without_polling_for_incoming_data() {
const PROTO_NAME: &str = "/test/proto/1";
let (listener_addr_tx, listener_addr_rx) = oneshot::channel();

let client = tokio::spawn(async move {
let socket = TcpStream::connect(listener_addr_rx.await.unwrap()).await.unwrap();
let NotificationsOutOpen { handshake, .. } = upgrade::apply_outbound(
socket.compat(),
NotificationsOut::new(PROTO_NAME, Vec::new(), &b"initial message"[..], 1024 * 1024),
upgrade::Version::V1,
)
.await
.unwrap();

assert_eq!(handshake, b"hello world");
});

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
listener_addr_tx.send(listener.local_addr().unwrap()).unwrap();

let (socket, _) = listener.accept().await.unwrap();
let NotificationsInOpen { handshake, mut substream, .. } = upgrade::apply_inbound(
socket.compat(),
NotificationsIn::new(PROTO_NAME, Vec::new(), 1024 * 1024),
)
.await
.unwrap();

assert_eq!(handshake, b"initial message");
substream.send_handshake(&b"hello world"[..]);

// Actually send the handshake.
future::poll_fn(|cx| Pin::new(&mut substream).poll_process(cx)).await.unwrap();

client.await.unwrap();
}

#[tokio::test]
async fn can_detect_dropped_out_substream_without_writing_data() {
const PROTO_NAME: &str = "/test/proto/1";
let (listener_addr_tx, listener_addr_rx) = oneshot::channel();

let client = tokio::spawn(async move {
let socket = TcpStream::connect(listener_addr_rx.await.unwrap()).await.unwrap();
let NotificationsOutOpen { handshake, mut substream, .. } = upgrade::apply_outbound(
socket.compat(),
NotificationsOut::new(PROTO_NAME, Vec::new(), &b"initial message"[..], 1024 * 1024),
upgrade::Version::V1,
)
.await
.unwrap();

assert_eq!(handshake, b"hello world");

future::poll_fn(|cx| match Pin::new(&mut substream).poll_flush(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(())) => {
cx.waker().wake_by_ref();
Poll::Pending
},
Poll::Ready(Err(e)) => {
assert!(matches!(e, NotificationsOutError::Terminated));
Poll::Ready(())
},
})
.await;
});

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
listener_addr_tx.send(listener.local_addr().unwrap()).unwrap();

let (socket, _) = listener.accept().await.unwrap();
let NotificationsInOpen { handshake, mut substream, .. } = upgrade::apply_inbound(
socket.compat(),
NotificationsIn::new(PROTO_NAME, Vec::new(), 1024 * 1024),
)
.await
.unwrap();

assert_eq!(handshake, b"initial message");

// Send the handhsake.
substream.send_handshake(&b"hello world"[..]);
future::poll_fn(|cx| Pin::new(&mut substream).poll_process(cx)).await.unwrap();

drop(substream);

client.await.unwrap();
}
}

0 comments on commit d933077

Please sign in to comment.