Skip to content

Commit

Permalink
feat(socketio/ns): improve emitter trait
Browse files Browse the repository at this point in the history
  • Loading branch information
Totodore committed Dec 27, 2024
1 parent f44d031 commit 50a36fe
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 87 deletions.
28 changes: 13 additions & 15 deletions crates/socketioxide-core/src/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,16 +171,16 @@ pub trait SocketEmitter: Send + Sync + 'static {
/// Get all the socket ids in the namespace.
fn get_all_sids(&self) -> Vec<Sid>;
/// Send data to the list of socket ids.
fn send_many(&self, sids: Vec<Sid>, data: Value) -> Result<(), Vec<SocketError>>;
fn send_many(&self, sids: &[Sid], data: Value) -> Result<(), Vec<SocketError>>;
/// Send data to the list of socket ids and get a stream of acks.
fn send_many_with_ack(
&self,
sids: Vec<Sid>,
sids: &[Sid],
packet: Packet,
timeout: Option<Duration>,
) -> Self::AckStream;
/// Disconnect all the sockets in the list.
fn disconnect_many(&self, sid: Vec<Sid>) -> Result<(), Vec<SocketError>>;
fn disconnect_many(&self, sid: &[Sid]) -> Result<(), Vec<SocketError>>;
/// Get the path of the namespace.
fn path(&self) -> &Str;
/// Get the parser of the namespace.
Expand Down Expand Up @@ -373,7 +373,7 @@ impl<E: SocketEmitter> CoreLocalAdapter<E> {
}

let data = self.sockets.parser().encode(packet);
self.sockets.send_many(sids, data)
self.sockets.send_many(&sids, data)
}

/// Broadcasts the packet to the sockets that match the [`BroadcastOptions`] and return a stream of ack responses.
Expand All @@ -390,13 +390,13 @@ impl<E: SocketEmitter> CoreLocalAdapter<E> {

let count = sids.len() as u32;
// We cannot pre-serialize the packet because we need to change the ack id.
let stream = self.sockets.send_many_with_ack(sids, packet, timeout);
let stream = self.sockets.send_many_with_ack(&sids, packet, timeout);
(stream, count)
}

/// Returns the sockets ids that match the [`BroadcastOptions`].
pub fn sockets(&self, opts: BroadcastOptions) -> Vec<Sid> {
self.apply_opts(opts)
self.apply_opts(opts).into_vec()
}

//TODO: make this operation O(1)
Expand Down Expand Up @@ -429,7 +429,7 @@ impl<E: SocketEmitter> CoreLocalAdapter<E> {
/// Disconnects the sockets that match the [`BroadcastOptions`].
pub fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), Vec<SocketError>> {
let sids = self.apply_opts(opts);
self.sockets.disconnect_many(sids)
self.sockets.disconnect_many(&sids)
}

/// Returns all the rooms for this adapter.
Expand All @@ -450,33 +450,31 @@ impl<E: SocketEmitter> CoreLocalAdapter<E> {

impl<E: SocketEmitter> CoreLocalAdapter<E> {
/// Applies the given `opts` and return the sockets that match.
fn apply_opts(&self, opts: BroadcastOptions) -> Vec<Sid> {
fn apply_opts(&self, opts: BroadcastOptions) -> SmallVec<[Sid; 16]> {
let is_broadcast = opts.has_flag(BroadcastFlags::Broadcast);
let rooms = opts.rooms;

let except = self.get_except_sids(&opts.except);
let is_socket_current = |id| opts.sid.map(|s| s != id).unwrap_or(true);
if !rooms.is_empty() {
let rooms_map = self.rooms.read().unwrap();
rooms
.iter()
.filter_map(|room| rooms_map.get(room))
.flatten()
.copied()
.filter(|id| {
!except.contains(id)
&& (!is_broadcast || opts.sid.map(|s| &s != id).unwrap_or(true))
})
.filter(|id| !except.contains(id) && (!is_broadcast || is_socket_current(*id)))
.collect()
} else if is_broadcast {
self.sockets
.get_all_sids()
.into_iter()
.filter(|id| !except.contains(id) && opts.sid.map(|s| &s != id).unwrap_or(true))
.filter(|id| !except.contains(id) && is_socket_current(*id))
.collect()
} else if let Some(id) = opts.sid {
vec![id]
smallvec::smallvec![id]
} else {
vec![]
smallvec::smallvec![]
}
}

Expand Down
29 changes: 12 additions & 17 deletions crates/socketioxide/src/ack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,19 +144,13 @@ impl AckInnerStream {
///
/// The [`AckInnerStream`] will wait for the default timeout specified in the config
/// (5s by default) if no custom timeout is specified.
pub fn broadcast<A: Adapter>(
pub fn broadcast<'a, A: Adapter>(
packet: Packet,
sockets: Vec<Arc<Socket<A>>>,
duration: Option<Duration>,
sockets: impl Iterator<Item = &'a Arc<Socket<A>>>,
duration: Duration,
) -> Self {
let rxs = FuturesUnordered::new();

if sockets.is_empty() {
return AckInnerStream::Stream { rxs };
}

let duration =
duration.unwrap_or_else(|| sockets.first().unwrap().get_io().config().ack_timeout);
for socket in sockets {
let rx = socket.send_with_ack(packet.clone());
rxs.push(AckResultWithId {
Expand Down Expand Up @@ -312,16 +306,17 @@ mod test {
Self::new(val, Parser::default())
}
}
const TIMEOUT: Duration = Duration::from_secs(5);

#[tokio::test]
async fn broadcast_ack() {
let socket = create_socket();
let socket2 = create_socket();
let mut packet = get_packet();
packet.inner.set_ack_id(1);
let socks = vec![socket.clone().into(), socket2.clone().into()];
let socks = vec![&socket, &socket2];
let stream: AckStream<String, LocalAdapter> =
AckInnerStream::broadcast(packet, socks, None).into();
AckInnerStream::broadcast(packet, socks.into_iter(), TIMEOUT).into();

let res_packet = Packet::ack("test", value("test"), 1);
socket.recv(res_packet.inner.clone()).unwrap();
Expand Down Expand Up @@ -365,9 +360,9 @@ mod test {
let socket2 = create_socket();
let mut packet = get_packet();
packet.inner.set_ack_id(1);
let socks = vec![socket.clone().into(), socket2.clone().into()];
let socks = vec![&socket, &socket2];
let stream: AckStream<String, LocalAdapter> =
AckInnerStream::broadcast(packet, socks, None).into();
AckInnerStream::broadcast(packet, socks.into_iter(), TIMEOUT).into();

let res_packet = Packet::ack("test", value(132), 1);
socket.recv(res_packet.inner.clone()).unwrap();
Expand Down Expand Up @@ -422,9 +417,9 @@ mod test {
let socket2 = create_socket();
let mut packet = get_packet();
packet.inner.set_ack_id(1);
let socks = vec![socket.clone().into(), socket2.clone().into()];
let socks = vec![&socket, &socket2];
let stream: AckStream<String, LocalAdapter> =
AckInnerStream::broadcast(packet, socks, None).into();
AckInnerStream::broadcast(packet, socks.into_iter(), TIMEOUT).into();

let res_packet = Packet::ack("test", value("test"), 1);
socket.clone().recv(res_packet.inner.clone()).unwrap();
Expand Down Expand Up @@ -478,9 +473,9 @@ mod test {
let socket2 = create_socket();
let mut packet = get_packet();
packet.inner.set_ack_id(1);
let socks = vec![socket.clone().into(), socket2.clone().into()];
let socks = vec![&socket, &socket2];
let stream: AckStream<String, LocalAdapter> =
AckInnerStream::broadcast(packet, socks, Some(Duration::from_millis(10))).into();
AckInnerStream::broadcast(packet, socks.into_iter(), Duration::from_millis(10)).into();

socket
.recv(Packet::ack("test", value("test"), 1).inner)
Expand Down
16 changes: 3 additions & 13 deletions crates/socketioxide/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl<A: Adapter> Client<A> {
// We have to create a new `Str` otherwise, we would keep a ref to the original connect packet
// for the entire lifetime of the Namespace.
let path = Str::copy_from_slice(&ns_path);
let ns = ns_ctr.get_new_ns(path.clone(), &self.adapter_state, self.config.parser);
let ns = ns_ctr.get_new_ns(path.clone(), &self.adapter_state, &self.config);
let this = self.clone();
let esocket = esocket.clone();
tokio::spawn(async move {
Expand Down Expand Up @@ -157,12 +157,7 @@ impl<A: Adapter> Client<A> {
tracing::debug!("adding namespace {}", path);

let ns_path = Str::from(&path);
let ns = Namespace::new(
ns_path.clone(),
callback,
&self.adapter_state,
self.config.parser,
);
let ns = Namespace::new(ns_path.clone(), callback, &self.adapter_state, &self.config);
// We spawn the adapter init task and therefore it might fail but the namespace is still added.
// The best solution would be to make the fn async and returning the error to the user.
// However this would require all .ns() calls to be async.
Expand Down Expand Up @@ -472,12 +467,7 @@ mod test {
#[tokio::test]
async fn get_ns() {
let client = create_client();
let ns = Namespace::new(
Str::from("/"),
|| {},
&client.adapter_state,
client.config.parser,
);
let ns = Namespace::new(Str::from("/"), || {}, &client.adapter_state, &client.config);
client.nsps.write().unwrap().insert(Str::from("/"), ns);
assert!(client.get_ns("/").is_some());
}
Expand Down
Loading

0 comments on commit 50a36fe

Please sign in to comment.