Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(socketio/ns): improve SocketEmitter trait #410

Merged
merged 3 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 163 additions & 66 deletions crates/socketioxide-core/src/adapter.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
//! The adapter module contains the [`CoreAdapter`] trait and other related types.
//!
//! It is used to implement communication between socket.io servers to share messages and state.
//!
//! The [`CoreLocalAdapter`] provide a local implementation that will allow any implementors to apply local
//! operations (`broadcast_with_ack`, `broadcast`, `rooms`, etc...).
use std::{
borrow::Cow,
collections::{HashMap, HashSet},
collections::{hash_set, HashMap, HashSet},
error::Error as StdError,
future::{self, Future},
slice,
sync::{Arc, RwLock},
time::Duration,
};
Expand Down Expand Up @@ -169,18 +173,18 @@ pub trait SocketEmitter: Send + Sync + 'static {
type AckStream: Stream<Item = AckStreamItem<Self::AckError>> + FusedStream + Send + 'static;

/// Get all the socket ids in the namespace.
fn get_all_sids(&self) -> Vec<Sid>;
fn get_all_sids(&self, filter: impl Fn(&Sid) -> bool) -> Vec<Sid>;
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
/// Send data to the list of socket ids.
fn send_many(&self, sids: Vec<Sid>, data: Value) -> Result<(), Vec<SocketError>>;
/// Send data to the list of socket ids and get a stream of acks.
fn send_many(&self, sids: BroadcastIter<'_>, data: Value) -> Result<(), Vec<SocketError>>;
/// Send data to the list of socket ids and get a stream of acks and the number of expected acks.
fn send_many_with_ack(
&self,
sids: Vec<Sid>,
sids: BroadcastIter<'_>,
packet: Packet,
timeout: Option<Duration>,
) -> Self::AckStream;
) -> (Self::AckStream, u32);
/// Disconnect all the sockets in the list.
fn disconnect_many(&self, sid: Vec<Sid>) -> Result<(), Vec<SocketError>>;
fn disconnect_many(&self, sids: BroadcastIter<'_>) -> Result<(), Vec<SocketError>>;
/// Get the path of the namespace.
fn path(&self) -> &Str;
/// Get the parser of the namespace.
Expand Down Expand Up @@ -364,10 +368,11 @@ impl<E: SocketEmitter> CoreLocalAdapter<E> {
opts: BroadcastOptions,
) -> Result<(), Vec<SocketError>> {
use crate::parser::Parse;
let sids = self.apply_opts(opts);
let room_map = self.rooms.read().unwrap();
let sids = self.apply_opts(&opts, &room_map);

#[cfg(feature = "tracing")]
tracing::debug!("broadcasting packet to {} sockets", sids.len());
tracing::debug!("broadcasting packet");
if sids.is_empty() {
return Ok(());
}
Expand All @@ -384,19 +389,19 @@ impl<E: SocketEmitter> CoreLocalAdapter<E> {
opts: BroadcastOptions,
timeout: Option<Duration>,
) -> (E::AckStream, u32) {
let sids = self.apply_opts(opts);
let room_map = self.rooms.read().unwrap();
let sids = self.apply_opts(&opts, &room_map);
#[cfg(feature = "tracing")]
tracing::debug!("broadcasting packet to {} sockets: {:?}", sids.len(), sids);
tracing::debug!("broadcasting packet");

let count = sids.len() as u32;
// We cannot pre-serialize the packet because we need to change the ack id.
let stream = self.sockets.send_many_with_ack(sids, packet, timeout);
(stream, count)
self.sockets.send_many_with_ack(sids, packet, timeout)
}

/// Returns the sockets ids that match the [`BroadcastOptions`].
pub fn sockets(&self, opts: BroadcastOptions) -> Vec<Sid> {
self.apply_opts(opts)
self.apply_opts(&opts, &self.rooms.read().unwrap())
.collect()
}

//TODO: make this operation O(1)
Expand All @@ -413,22 +418,31 @@ impl<E: SocketEmitter> CoreLocalAdapter<E> {
/// Adds the sockets that match the [`BroadcastOptions`] to the rooms.
pub fn add_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) {
let rooms: Vec<Room> = rooms.into_room_iter().collect();
for sid in self.apply_opts(opts) {
// Here we have to collect sids, because we are going to modify the rooms map.
let sids = self
.apply_opts(&opts, &self.rooms.read().unwrap())
.collect::<Vec<_>>();
for sid in sids {
self.add_all(sid, rooms.clone());
}
}

/// Removes the sockets that match the [`BroadcastOptions`] from the rooms.
pub fn del_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) {
let rooms: Vec<Room> = rooms.into_room_iter().collect();
for sid in self.apply_opts(opts) {
// Here we have to collect sids, because we are going to modify the rooms map.
let sids = self
.apply_opts(&opts, &self.rooms.read().unwrap())
.collect::<Vec<_>>();
for sid in sids {
self.del(sid, rooms.clone());
}
}

/// Disconnects the sockets that match the [`BroadcastOptions`].
pub fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), Vec<SocketError>> {
let sids = self.apply_opts(opts);
let room_map = self.rooms.read().unwrap();
let sids = self.apply_opts(&opts, &room_map);
self.sockets.disconnect_many(sids)
}

Expand All @@ -448,35 +462,70 @@ impl<E: SocketEmitter> CoreLocalAdapter<E> {
}
}

/// The default broadcast iterator.
/// Extract, flatten and filter a list of sid from a room list
struct BroadcastRooms<'a> {
rooms: slice::Iter<'a, Room>,
rooms_map: &'a HashMap<Room, HashSet<Sid>>,
except: HashSet<Sid>,
flatten_iter: Option<hash_set::Iter<'a, Sid>>,
}
impl<'a> BroadcastRooms<'a> {
fn new(
rooms: &'a [Room],
rooms_map: &'a HashMap<Room, HashSet<Sid>>,
except: HashSet<Sid>,
) -> Self {
BroadcastRooms {
rooms: rooms.iter(),
rooms_map,
except,
flatten_iter: None,
}
}
}
impl Iterator for BroadcastRooms<'_> {
type Item = Sid;
fn next(&mut self) -> Option<Self::Item> {
loop {
match self.flatten_iter.as_mut().and_then(Iterator::next) {
Some(sid) if !self.except.contains(sid) => return Some(*sid),
Some(_) => continue,
None => self.flatten_iter = None,
}

let room = self.rooms.next()?;
self.flatten_iter = self.rooms_map.get(room).map(HashSet::iter);
}
}
}

impl<E: SocketEmitter> CoreLocalAdapter<E> {
/// Applies the given `opts` and return the sockets that match.
fn apply_opts(&self, opts: BroadcastOptions) -> Vec<Sid> {
fn apply_opts<'a>(
&self,
opts: &'a BroadcastOptions,
rooms: &'a HashMap<Room, HashSet<Sid>>,
) -> BroadcastIter<'a> {
let is_broadcast = opts.has_flag(BroadcastFlags::Broadcast);
let rooms = opts.rooms;

let except = self.get_except_sids(&opts.except);
if !rooms.is_empty() {
let rooms_map = self.rooms.read().unwrap();
rooms
.iter()
.filter_map(|room| rooms_map.get(room))
.flatten()
.copied()
.filter(|id| {
!except.contains(id)
&& (!is_broadcast || opts.sid.map(|s| &s != id).unwrap_or(true))
})
.collect()
let mut except = self.get_except_sids(&opts.except);
// In case of broadcast flag + if the sender is set,
// we should not broadcast to it.
if is_broadcast && opts.sid.is_some() {
except.insert(opts.sid.unwrap());
}

if !opts.rooms.is_empty() {
let iter = BroadcastRooms::new(&opts.rooms, rooms, except);
InnerBroadcastIter::BroadcastRooms(iter).into()
} else if is_broadcast {
self.sockets
.get_all_sids()
.into_iter()
.filter(|id| !except.contains(id) && opts.sid.map(|s| &s != id).unwrap_or(true))
.collect()
let sids = self.sockets.get_all_sids(|id| !except.contains(id));
InnerBroadcastIter::GlobalBroadcast(sids.into_iter()).into()
} else if let Some(id) = opts.sid {
vec![id]
InnerBroadcastIter::Single(id).into()
} else {
vec![]
InnerBroadcastIter::None.into()
}
}

Expand All @@ -492,11 +541,59 @@ impl<E: SocketEmitter> CoreLocalAdapter<E> {
}
}

/// An iterator that yields the socket ids that match the broadcast options.
/// Used with the [`SocketEmitter`] interface.
pub struct BroadcastIter<'a> {
inner: InnerBroadcastIter<'a>,
}
enum InnerBroadcastIter<'a> {
BroadcastRooms(BroadcastRooms<'a>),
GlobalBroadcast(<Vec<Sid> as IntoIterator>::IntoIter),
Single(Sid),
None,
}
impl BroadcastIter<'_> {
fn is_empty(&self) -> bool {
matches!(self.inner, InnerBroadcastIter::None)
}
}
impl<'a> From<InnerBroadcastIter<'a>> for BroadcastIter<'a> {
fn from(inner: InnerBroadcastIter<'a>) -> Self {
BroadcastIter { inner }
}
}

impl Iterator for BroadcastIter<'_> {
type Item = Sid;

#[inline(always)]
fn next(&mut self) -> Option<Self::Item> {
self.inner.next()
}
}
impl Iterator for InnerBroadcastIter<'_> {
type Item = Sid;

fn next(&mut self) -> Option<Self::Item> {
match self {
InnerBroadcastIter::BroadcastRooms(inner) => inner.next(),
InnerBroadcastIter::GlobalBroadcast(inner) => inner.next(),
InnerBroadcastIter::Single(sid) => {
let sid = *sid;
*self = InnerBroadcastIter::None;
Some(sid)
}
InnerBroadcastIter::None => None,
}
}
}

#[cfg(test)]
mod test {

use smallvec::smallvec;
use std::{
array,
pin::Pin,
task::{Context, Poll},
};
Expand Down Expand Up @@ -542,24 +639,28 @@ mod test {
type AckError = StubError;
type AckStream = StubAckStream;

fn get_all_sids(&self) -> Vec<Sid> {
self.sockets.iter().copied().collect()
fn get_all_sids(&self, filter: impl Fn(&Sid) -> bool) -> Vec<Sid> {
self.sockets
.iter()
.copied()
.filter(|id| filter(id))
.collect()
}

fn send_many(&self, _: Vec<Sid>, _: Value) -> Result<(), Vec<SocketError>> {
fn send_many(&self, _: BroadcastIter<'_>, _: Value) -> Result<(), Vec<SocketError>> {
Ok(())
}

fn send_many_with_ack(
&self,
_: Vec<Sid>,
_: BroadcastIter<'_>,
_: Packet,
_: Option<Duration>,
) -> Self::AckStream {
StubAckStream
) -> (Self::AckStream, u32) {
(StubAckStream, 0)
}

fn disconnect_many(&self, _: Vec<Sid>) -> Result<(), Vec<SocketError>> {
fn disconnect_many(&self, _: BroadcastIter<'_>) -> Result<(), Vec<SocketError>> {
Ok(())
}

Expand Down Expand Up @@ -738,43 +839,39 @@ mod test {
}
#[test]
fn test_apply_opts() {
let socket0 = Sid::new();
let socket1 = Sid::new();
let socket2 = Sid::new();
let adapter = create_adapter([socket0, socket1, socket2]);
let mut sockets: [Sid; 3] = array::from_fn(|_| Sid::new());
sockets.sort();
let adapter = create_adapter(sockets);
// Add socket 0 to room1 and room2
adapter.add_all(socket0, ["room1", "room2"]);
adapter.add_all(sockets[0], ["room1", "room2"]);
// Add socket 1 to room1 and room3
adapter.add_all(socket1, ["room1", "room3"]);
adapter.add_all(sockets[1], ["room1", "room3"]);
// Add socket 2 to room2 and room3
adapter.add_all(socket2, ["room1", "room2", "room3"]);
adapter.add_all(sockets[2], ["room1", "room2", "room3"]);

// socket 2 is the sender
let mut opts = BroadcastOptions::new(socket2);
let mut opts = BroadcastOptions::new(sockets[2]);
opts.rooms = smallvec!["room1".into()];
opts.except = smallvec!["room2".into()];
let sids = adapter.sockets(opts);
assert_eq!(sids.len(), 1);
assert_eq!(sids[0], socket1);
assert_eq!(sids, [sockets[1]]);

let mut opts = BroadcastOptions::new(socket2);
let mut opts = BroadcastOptions::new(sockets[2]);
opts.add_flag(BroadcastFlags::Broadcast);
let sids = adapter.sockets(opts);
assert_eq!(sids.len(), 2);
sids.into_iter().for_each(|id| {
assert!(id == socket0 || id == socket1);
});
let mut sids = adapter.sockets(opts);
sids.sort();
assert_eq!(sids, [sockets[0], sockets[1]]);

let mut opts = BroadcastOptions::new(socket2);
let mut opts = BroadcastOptions::new(sockets[2]);
opts.add_flag(BroadcastFlags::Broadcast);
opts.except = smallvec!["room2".into()];
let sids = adapter.sockets(opts);
assert_eq!(sids.len(), 1);

let opts = BroadcastOptions::new(socket2);
let opts = BroadcastOptions::new(sockets[2]);
let sids = adapter.sockets(opts);
assert_eq!(sids.len(), 1);
assert_eq!(sids[0], socket2);
assert_eq!(sids[0], sockets[2]);

let opts = BroadcastOptions::new(Sid::new());
let sids = adapter.sockets(opts);
Expand Down
Loading
Loading