Skip to content

Commit

Permalink
require &mut self for accept
Browse files Browse the repository at this point in the history
  • Loading branch information
rklaehn committed Dec 10, 2024
1 parent 0e982c2 commit 5c7c6f7
Show file tree
Hide file tree
Showing 19 changed files with 31 additions and 33 deletions.
2 changes: 1 addition & 1 deletion examples/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ async fn main() -> anyhow::Result<()> {
let fs = Fs;
let (server, client) = quic_rpc::transport::flume::channel(1);
let client = RpcClient::<IoService, _>::new(client);
let server = RpcServer::new(server);
let mut server = RpcServer::new(server);
let handle = tokio::task::spawn(async move {
for _ in 0..1 {
let (req, chan) = server.accept().await?.read_first().await?;
Expand Down
4 changes: 2 additions & 2 deletions examples/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ async fn main() -> anyhow::Result<()> {
async fn server_future<C: Listener<StoreService>>(
server: RpcServer<StoreService, C>,
) -> result::Result<(), RpcServerError<C>> {
let s = server;
let mut s = server;
let store = Store;
loop {
let (req, chan) = s.accept().await?.read_first().await?;
Expand Down Expand Up @@ -239,7 +239,7 @@ async fn _main_unsugared() -> anyhow::Result<()> {
type Req = u64;
type Res = String;
}
let (server, client) = flume::channel::<u64, String>(1);
let (mut server, client) = flume::channel::<u64, String>(1);
let to_string_service = tokio::spawn(async move {
let (mut send, mut recv) = server.accept().await?;
while let Some(item) = recv.next().await {
Expand Down
6 changes: 3 additions & 3 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ impl<S: Service, C: Listener<S>> Accepting<S, C> {
impl<S: Service, C: Listener<S>> RpcServer<S, C> {
/// Accepts a new channel from a client. The result is an [Accepting] object that
/// can be used to read the first request.
pub async fn accept(&self) -> result::Result<Accepting<S, C>, RpcServerError<C>> {
pub async fn accept(&mut self) -> result::Result<Accepting<S, C>, RpcServerError<C>> {
let (send, recv) = self.source.accept().await.map_err(RpcServerError::Accept)?;
Ok(Accepting {
send,
Expand All @@ -211,7 +211,7 @@ impl<S: Service, C: Listener<S>> RpcServer<S, C> {
/// Each request will be handled in a separate task.
///
/// It is the caller's responsibility to poll the returned future to drive the server.
pub async fn accept_loop<Fun, Fut, E>(self, handler: Fun)
pub async fn accept_loop<Fun, Fut, E>(mut self, handler: Fun)
where
S: Service,
C: Listener<S>,
Expand Down Expand Up @@ -453,7 +453,7 @@ where
F: FnMut(RpcChannel<S, C>, S::Req, T) -> Fut + Send + 'static,
Fut: Future<Output = Result<(), RpcServerError<C>>> + Send + 'static,
{
let server: RpcServer<S, C> = RpcServer::<S, C>::new(conn);
let mut server: RpcServer<S, C> = RpcServer::<S, C>::new(conn);
loop {
let (req, chan) = server.accept().await?.read_first().await?;
let target = target.clone();
Expand Down
12 changes: 6 additions & 6 deletions src/transport/boxed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ impl<In: RpcMessage, Out: RpcMessage> StreamTypes for BoxedStreamTypes<In, Out>
/// A boxable listener
pub trait BoxableListener<In: RpcMessage, Out: RpcMessage>: Debug + Send + Sync + 'static {
/// Accept a channel from a remote client
fn accept_bi_boxed(&self) -> AcceptFuture<In, Out>;
fn accept_bi_boxed(&mut self) -> AcceptFuture<In, Out>;

/// Get the local address
fn local_addr(&self) -> &[super::LocalAddr];
Expand Down Expand Up @@ -324,7 +324,7 @@ impl<In: RpcMessage, Out: RpcMessage> ConnectionErrors for BoxedListener<In, Out

impl<In: RpcMessage, Out: RpcMessage> super::Listener for BoxedListener<In, Out> {
fn accept(
&self,
&mut self,
) -> impl Future<Output = Result<(Self::SendSink, Self::RecvStream), Self::AcceptError>> + Send
{
self.0.accept_bi_boxed()
Expand Down Expand Up @@ -369,7 +369,7 @@ impl<In: RpcMessage, Out: RpcMessage> BoxableConnector<In, Out>
impl<In: RpcMessage, Out: RpcMessage> BoxableListener<In, Out>
for super::quinn::QuinnListener<In, Out>
{
fn accept_bi_boxed(&self) -> AcceptFuture<In, Out> {
fn accept_bi_boxed(&mut self) -> AcceptFuture<In, Out> {
let f = async move {
let (send, recv) = super::Listener::accept(self).await?;
let send = send.sink_map_err(anyhow::Error::from);
Expand Down Expand Up @@ -409,7 +409,7 @@ impl<In: RpcMessage, Out: RpcMessage> BoxableConnector<In, Out>
impl<In: RpcMessage, Out: RpcMessage> BoxableListener<In, Out>
for super::iroh::IrohListener<In, Out>
{
fn accept_bi_boxed(&self) -> AcceptFuture<In, Out> {
fn accept_bi_boxed(&mut self) -> AcceptFuture<In, Out> {
let f = async move {
let (send, recv) = super::Listener::accept(self).await?;
let send = send.sink_map_err(anyhow::Error::from);
Expand Down Expand Up @@ -441,7 +441,7 @@ impl<In: RpcMessage, Out: RpcMessage> BoxableConnector<In, Out>
impl<In: RpcMessage, Out: RpcMessage> BoxableListener<In, Out>
for super::flume::FlumeListener<In, Out>
{
fn accept_bi_boxed(&self) -> AcceptFuture<In, Out> {
fn accept_bi_boxed(&mut self) -> AcceptFuture<In, Out> {
AcceptFuture::direct(super::Listener::accept(self))
}

Expand Down Expand Up @@ -499,7 +499,7 @@ mod tests {
use crate::transport::{Connector, Listener};

let (server, client) = crate::transport::flume::channel(1);
let server = super::BoxedListener::new(server);
let mut server = super::BoxedListener::new(server);
let client = super::BoxedConnector::new(client);
// spawn echo server
tokio::spawn(async move {
Expand Down
6 changes: 3 additions & 3 deletions src/transport/combined.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,17 +241,17 @@ impl<A: Listener, B: Listener<In = A::In, Out = A::Out>> StreamTypes for Combine
}

impl<A: Listener, B: Listener<In = A::In, Out = A::Out>> Listener for CombinedListener<A, B> {
async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::AcceptError> {
async fn accept(&mut self) -> Result<(Self::SendSink, Self::RecvStream), Self::AcceptError> {
let a_fut = async {
if let Some(a) = &self.a {
if let Some(a) = &mut self.a {
let (send, recv) = a.accept().await.map_err(AcceptError::A)?;
Ok((SendSink::A(send), RecvStream::A(recv)))
} else {
std::future::pending().await
}
};
let b_fut = async {
if let Some(b) = &self.b {
if let Some(b) = &mut self.b {
let (send, recv) = b.accept().await.map_err(AcceptError::B)?;
Ok((SendSink::B(send), RecvStream::B(recv)))
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/transport/flume.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ impl<In: RpcMessage, Out: RpcMessage> StreamTypes for FlumeListener<In, Out> {

impl<In: RpcMessage, Out: RpcMessage> Listener for FlumeListener<In, Out> {
#[allow(refining_impl_trait)]
fn accept(&self) -> AcceptFuture<In, Out> {
fn accept(&mut self) -> AcceptFuture<In, Out> {
AcceptFuture {
wrapped: self.stream.clone().into_recv_async(),
_p: PhantomData,
Expand Down
2 changes: 1 addition & 1 deletion src/transport/hyper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ impl<In: RpcMessage, Out: RpcMessage> Listener for HyperListener<In, Out> {
&self.local_addr
}

async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), AcceptError> {
async fn accept(&mut self) -> Result<(Self::SendSink, Self::RecvStream), AcceptError> {
let (recv, send) = self
.channel
.recv_async()
Expand Down
2 changes: 1 addition & 1 deletion src/transport/iroh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ impl<In: RpcMessage, Out: RpcMessage> StreamTypes for IrohListener<In, Out> {
}

impl<In: RpcMessage, Out: RpcMessage> Listener for IrohListener<In, Out> {
async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), AcceptError> {
async fn accept(&mut self) -> Result<(Self::SendSink, Self::RecvStream), AcceptError> {
let (send, recv) = self
.inner
.receiver
Expand Down
2 changes: 1 addition & 1 deletion src/transport/mapped.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ mod tests {
// create a listener / connector pair. Type will be inferred
let (s, c) = crate::transport::flume::channel(32);
// wrap the server in a RpcServer, this is where the service type is specified
let server = RpcServer::<FullService, _>::new(s.clone());
let mut server = RpcServer::<FullService, _>::new(s.clone());
// when using a boxed transport, we can omit the transport type and use the default
let _server_boxed: RpcServer<FullService> = RpcServer::<FullService>::new(s.boxed());
// create a client in a RpcClient, this is where the service type is specified
Expand Down
2 changes: 1 addition & 1 deletion src/transport/misc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl<In: RpcMessage, Out: RpcMessage> StreamTypes for DummyListener<In, Out> {
}

impl<In: RpcMessage, Out: RpcMessage> Listener for DummyListener<In, Out> {
async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::AcceptError> {
async fn accept(&mut self) -> Result<(Self::SendSink, Self::RecvStream), Self::AcceptError> {
futures_lite::future::pending().await
}

Expand Down
2 changes: 1 addition & 1 deletion src/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ pub trait Listener: StreamTypes {
/// Accept a new typed bidirectional channel on any of the connections we
/// have currently opened.
fn accept(
&self,
&mut self,
) -> impl Future<Output = Result<(Self::SendSink, Self::RecvStream), Self::AcceptError>> + Send;

/// The local addresses this endpoint is bound to.
Expand Down
2 changes: 1 addition & 1 deletion src/transport/quinn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ impl<In: RpcMessage, Out: RpcMessage> StreamTypes for QuinnListener<In, Out> {
}

impl<In: RpcMessage, Out: RpcMessage> Listener for QuinnListener<In, Out> {
async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), AcceptError> {
async fn accept(&mut self) -> Result<(Self::SendSink, Self::RecvStream), AcceptError> {
let (send, recv) = self
.inner
.receiver
Expand Down
8 changes: 3 additions & 5 deletions src/transport/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ impl error::Error for RecvError {}
/// Created using [channel].
pub struct MemListener<In: RpcMessage, Out: RpcMessage> {
#[allow(clippy::type_complexity)]
stream: tokio::sync::Mutex<tokio::sync::mpsc::Receiver<(SendSink<Out>, RecvStream<In>)>>,
stream: tokio::sync::mpsc::Receiver<(SendSink<Out>, RecvStream<In>)>,
}

impl<In: RpcMessage, Out: RpcMessage> fmt::Debug for MemListener<In, Out> {
Expand Down Expand Up @@ -183,9 +183,8 @@ impl<In: RpcMessage, Out: RpcMessage> StreamTypes for MemListener<In, Out> {
}

impl<In: RpcMessage, Out: RpcMessage> Listener for MemListener<In, Out> {
async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), AcceptError> {
let mut stream = self.stream.lock().await;
match stream.recv().await {
async fn accept(&mut self) -> Result<(Self::SendSink, Self::RecvStream), AcceptError> {
match self.stream.recv().await {
Some((send, recv)) => Ok((send, recv)),
None => Err(AcceptError::RemoteDropped),
}
Expand Down Expand Up @@ -323,6 +322,5 @@ pub fn channel<Req: RpcMessage, Res: RpcMessage>(
buffer: usize,
) -> (MemListener<Req, Res>, MemConnector<Res, Req>) {
let (sink, stream) = tokio::sync::mpsc::channel(buffer);
let stream = tokio::sync::Mutex::new(stream);
(MemListener { stream }, MemConnector { sink })
}
2 changes: 1 addition & 1 deletion tests/flume.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ async fn flume_channel_mapped_bench() -> anyhow::Result<()> {
}
let (server, client) = flume::channel(1);

let server = RpcServer::<OuterService, _>::new(server);
let mut server = RpcServer::<OuterService, _>::new(server);
let server_handle: tokio::task::JoinHandle<Result<(), RpcServerError<_>>> =
tokio::task::spawn(async move {
let service = ComputeService;
Expand Down
2 changes: 1 addition & 1 deletion tests/hyper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ async fn hyper_channel_errors() -> anyhow::Result<()> {
Receiver<result::Result<(), RpcServerError<SC>>>,
) {
let channel = HyperListener::serve(addr).unwrap();
let server = RpcServer::new(channel);
let mut server = RpcServer::new(channel);
let (res_tx, res_rx) = flume::unbounded();
let handle = tokio::spawn(async move {
loop {
Expand Down
2 changes: 1 addition & 1 deletion tests/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ impl ComputeService {
count: usize,
) -> result::Result<RpcServer<ComputeService, C>, RpcServerError<C>> {
tracing::info!(%count, "server running");
let s = server;
let mut s = server;
let mut received = 0;
let service = ComputeService;
while received < count {
Expand Down
2 changes: 1 addition & 1 deletion tests/slow_math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ impl ComputeService {
pub async fn server<C: Listener<ComputeService>>(
server: RpcServer<ComputeService, C>,
) -> result::Result<(), RpcServerError<C>> {
let s = server;
let mut s = server;
let service = ComputeService;
loop {
let (req, chan) = s.accept().await?.read_first().await?;
Expand Down
2 changes: 1 addition & 1 deletion tests/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ async fn tokio_channel_mapped_bench() -> anyhow::Result<()> {
}
let (server, client) = tkio::channel(1);

let server = RpcServer::<OuterService, _>::new(server);
let mut server = RpcServer::<OuterService, _>::new(server);
let server_handle: tokio::task::JoinHandle<Result<(), RpcServerError<_>>> =
tokio::task::spawn(async move {
let service = ComputeService;
Expand Down
2 changes: 1 addition & 1 deletion tests/try.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ async fn try_server_streaming() -> anyhow::Result<()> {
tracing_subscriber::fmt::try_init().ok();
let (server, client) = flume::channel(1);

let server = RpcServer::<TryService, _>::new(server);
let mut server = RpcServer::<TryService, _>::new(server);
let server_handle = tokio::task::spawn(async move {
loop {
let (req, chan) = server.accept().await?.read_first().await?;
Expand Down

0 comments on commit 5c7c6f7

Please sign in to comment.