Skip to content

Commit

Permalink
feat(torii-client): add entity change listener (dojoengine#1072)
Browse files Browse the repository at this point in the history
  • Loading branch information
kariy authored and Mateusz Zając committed Oct 21, 2023
1 parent 6130f04 commit 50194a9
Show file tree
Hide file tree
Showing 15 changed files with 507 additions and 199 deletions.
4 changes: 3 additions & 1 deletion crates/torii/client/src/client/error.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use dojo_world::contracts::model::ModelError;
use starknet::core::utils::CairoShortStringToFeltError;
use starknet::core::utils::{CairoShortStringToFeltError, ParseCairoShortStringError};
use starknet::providers::jsonrpc::HttpTransport;
use starknet::providers::{JsonRpcClient, Provider};

Expand Down Expand Up @@ -30,4 +30,6 @@ pub enum ParseError {
FeltFromStr(#[from] starknet::core::types::FromStrError),
#[error(transparent)]
CairoShortStringToFelt(#[from] CairoShortStringToFeltError),
#[error(transparent)]
ParseCairoShortString(#[from] ParseCairoShortStringError),
}
41 changes: 23 additions & 18 deletions crates/torii/client/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,26 @@ pub mod storage;
pub mod subscription;

use std::cell::OnceCell;
use std::collections::HashSet;
use std::sync::Arc;

use dojo_types::packing::unpack;
use dojo_types::schema::{EntityModel, Ty};
use dojo_types::WorldMetadata;
use dojo_world::contracts::WorldContractReader;
use parking_lot::RwLock;
use parking_lot::{RwLock, RwLockReadGuard};
use starknet::core::utils::cairo_short_string_to_felt;
use starknet::providers::jsonrpc::HttpTransport;
use starknet::providers::JsonRpcClient;
use starknet_crypto::FieldElement;
use tokio::sync::RwLock as AsyncRwLock;
use torii_grpc::protos::world::SubscribeEntitiesResponse;
use torii_grpc::client::EntityUpdateStreaming;

use self::error::{Error, ParseError};
use self::storage::ModelStorage;
use self::subscription::{SubscribedEntities, SubscriptionClientHandle};
use crate::client::subscription::SubscriptionService;

// TODO: expose the World interface from the `Client`
// TODO: remove reliance on RPC
#[allow(unused)]
pub struct Client {
Expand All @@ -46,9 +46,13 @@ impl Client {
ClientBuilder::new()
}

/// Returns the metadata of the world that the client is connected to.
pub fn metadata(&self) -> WorldMetadata {
self.metadata.read().clone()
/// Returns a read lock on the World metadata that the client is connected to.
pub fn metadata(&self) -> RwLockReadGuard<'_, WorldMetadata> {
self.metadata.read()
}

pub fn subscribed_entities(&self) -> RwLockReadGuard<'_, HashSet<EntityModel>> {
self.subscribed_entities.entities.read()
}

/// Returns the model value of an entity.
Expand All @@ -59,7 +63,7 @@ impl Client {
let mut schema = self.metadata.read().model(model).map(|m| m.schema.clone())?;

let Ok(Some(raw_values)) =
self.storage.get_entity(cairo_short_string_to_felt(model).ok()?, keys)
self.storage.get_entity_storage(cairo_short_string_to_felt(model).ok()?, keys)
else {
return Some(schema);
};
Expand All @@ -79,15 +83,10 @@ impl Client {
Some(schema)
}

/// Returns the list of entities that the client is subscribed to.
pub fn synced_entities(&self) -> Vec<EntityModel> {
self.subscribed_entities.entities.read().clone().into_iter().collect()
}

/// Initiate the entity subscriptions and returns a [SubscriptionService] which when await'ed
/// will execute the subscription service and starts the syncing process.
pub async fn start_subscription(&self) -> Result<SubscriptionService, Error> {
let entities = self.synced_entities();
let entities = self.subscribed_entities.entities.read().clone().into_iter().collect();
let sub_res_stream = self.initiate_subscription(entities).await?;

let (service, handle) = SubscriptionService::new(
Expand All @@ -111,7 +110,8 @@ impl Client {

self.subscribed_entities.add_entities(entities)?;

let updated_entities = self.synced_entities();
let updated_entities =
self.subscribed_entities.entities.read().clone().into_iter().collect();
let sub_res_stream = self.initiate_subscription(updated_entities).await?;

match self.sub_client_handle.get() {
Expand All @@ -127,7 +127,8 @@ impl Client {
pub async fn remove_entities_to_sync(&self, entities: Vec<EntityModel>) -> Result<(), Error> {
self.subscribed_entities.remove_entities(entities)?;

let updated_entities = self.synced_entities();
let updated_entities =
self.subscribed_entities.entities.read().clone().into_iter().collect();
let sub_res_stream = self.initiate_subscription(updated_entities).await?;

match self.sub_client_handle.get() {
Expand All @@ -137,10 +138,14 @@ impl Client {
Ok(())
}

pub fn storage(&self) -> Arc<ModelStorage> {
Arc::clone(&self.storage)
}

async fn initiate_subscription(
&self,
entities: Vec<EntityModel>,
) -> Result<tonic::Streaming<SubscribeEntitiesResponse>, Error> {
) -> Result<EntityUpdateStreaming, Error> {
let mut grpc_client = self.inner.write().await;
let stream = grpc_client.subscribe_entities(entities).await?;
Ok(stream)
Expand All @@ -149,7 +154,7 @@ impl Client {
async fn initiate_entity(&self, model: &str, keys: Vec<FieldElement>) -> Result<(), Error> {
let model_reader = self.world_reader.model(model).await?;
let values = model_reader.entity_storage(&keys).await?;
self.storage.set_entity(
self.storage.set_entity_storage(
cairo_short_string_to_felt(model).map_err(ParseError::CairoShortStringToFelt)?,
keys,
values,
Expand Down Expand Up @@ -207,7 +212,7 @@ impl ClientBuilder {
let model_reader = world_reader.model(&model).await?;
let values = model_reader.entity_storage(&keys).await?;

client_storage.set_entity(
client_storage.set_entity_storage(
cairo_short_string_to_felt(&model).unwrap(),
keys,
values,
Expand Down
126 changes: 99 additions & 27 deletions crates/torii/client/src/client/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use dojo_types::WorldMetadata;
use parking_lot::RwLock;
use futures::channel::mpsc::{channel, Receiver, Sender};
use parking_lot::{Mutex, RwLock};
use starknet::core::utils::parse_cairo_short_string;
use starknet_crypto::FieldElement;

use super::error::Error;
use super::error::{Error, ParseError};
use crate::utils::compute_all_storage_addresses;

pub type EntityKeys = Vec<FieldElement>;
Expand All @@ -17,25 +18,77 @@ pub type StorageValue = FieldElement;

/// An in-memory storage for storing the component values of entities.
// TODO: check if we can use sql db instead.
pub(crate) struct ModelStorage {
pub struct ModelStorage {
metadata: Arc<RwLock<WorldMetadata>>,
pub(crate) storage: RwLock<HashMap<StorageKey, StorageValue>>,
storage: RwLock<HashMap<StorageKey, StorageValue>>,
// a map of model name to a set of entity keys.
model_index: RwLock<HashMap<FieldElement, HashSet<EntityKeys>>>,

// listener for storage updates.
senders: Mutex<HashMap<u8, Sender<()>>>,
listeners: Mutex<HashMap<StorageKey, Vec<u8>>>,
}

impl ModelStorage {
pub(super) fn new(metadata: Arc<RwLock<WorldMetadata>>) -> Self {
Self { metadata, storage: Default::default(), model_index: Default::default() }
Self {
metadata,
storage: Default::default(),
model_index: Default::default(),
senders: Default::default(),
listeners: Default::default(),
}
}

/// Listen to entity changes.
///
/// # Arguments
/// * `model` - the model name.
/// * `keys` - the keys of the entity.
///
/// # Returns
/// A receiver that will receive updates for the specified storage keys.
pub fn add_listener(
&self,
model: FieldElement,
keys: &[FieldElement],
) -> Result<Receiver<()>, Error> {
let storage_addresses = self.get_entity_storage_addresses(model, keys)?;

let (sender, receiver) = channel(128);
let listener_id = self.senders.lock().len() as u8;
self.senders.lock().insert(listener_id, sender);

storage_addresses.iter().for_each(|key| {
self.listeners.lock().entry(*key).or_default().push(listener_id);
});

Ok(receiver)
}

#[allow(unused)]
pub(super) fn set_entity(
/// Retrieves the raw values of an entity.
pub fn get_entity_storage(
&self,
model: FieldElement,
raw_keys: &[FieldElement],
) -> Result<Option<Vec<FieldElement>>, Error> {
let storage_addresses = self.get_entity_storage_addresses(model, raw_keys)?;
Ok(storage_addresses
.into_iter()
.map(|storage_address| self.storage.read().get(&storage_address).copied())
.collect::<Option<Vec<_>>>())
}

/// Set the raw values of an entity.
pub fn set_entity_storage(
&self,
model: FieldElement,
raw_keys: Vec<FieldElement>,
raw_values: Vec<FieldElement>,
) -> Result<(), Error> {
let model_name = parse_cairo_short_string(&model).expect("valid cairo short string");
let model_name =
parse_cairo_short_string(&model).map_err(ParseError::ParseCairoShortString)?;

let model_packed_size = self
.metadata
.read()
Expand All @@ -55,35 +108,54 @@ impl ModelStorage {
Ordering::Equal => {}
}

let storage_addresses = compute_all_storage_addresses(model, &raw_keys, model_packed_size);
storage_addresses.iter().zip(&raw_values).for_each(|(storage_address, value)| {
self.storage.write().insert(*storage_address, *value);
});
let storage_addresses = self.get_entity_storage_addresses(model, &raw_keys)?;
self.set_storages_at(storage_addresses.into_iter().zip(raw_values).collect());
self.index_entity(model, raw_keys);

Ok(())
}

pub(super) fn get_entity(
/// Set the value of storage slots in bulk
pub(super) fn set_storages_at(&self, storage_entries: Vec<(FieldElement, FieldElement)>) {
let mut senders: HashSet<u8> = Default::default();

for (key, _) in &storage_entries {
if let Some(lists) = self.listeners.lock().get(key) {
for id in lists {
senders.insert(*id);
}
}
}

self.storage.write().extend(storage_entries);

for sender_id in senders {
self.notify_listener(sender_id);
}
}

fn notify_listener(&self, id: u8) {
if let Some(sender) = self.senders.lock().get_mut(&id) {
let _ = sender.try_send(());
}
}

fn get_entity_storage_addresses(
&self,
model: FieldElement,
raw_keys: &[FieldElement],
) -> Result<Option<Vec<FieldElement>>, Error> {
let model_name = parse_cairo_short_string(&model).expect("valid cairo short string");
) -> Result<Vec<FieldElement>, Error> {
let model_name =
parse_cairo_short_string(&model).map_err(ParseError::ParseCairoShortString)?;

let model_packed_size = self
.metadata
.read()
.model(&parse_cairo_short_string(&model).expect("valid cairo short string"))
.model(&model_name)
.map(|c| c.packed_size)
.ok_or(Error::UnknownModel(model_name))?;

let storage_addresses = compute_all_storage_addresses(model, raw_keys, model_packed_size);
let values = storage_addresses
.into_iter()
.map(|storage_address| self.storage.read().get(&storage_address).copied())
.collect::<Option<Vec<_>>>();

Ok(values)
Ok(compute_all_storage_addresses(model, raw_keys, model_packed_size))
}

fn index_entity(&self, model: FieldElement, raw_keys: Vec<FieldElement>) {
Expand Down Expand Up @@ -136,7 +208,7 @@ mod tests {

let values = vec![felt!("1"), felt!("2"), felt!("3"), felt!("4"), felt!("5")];
let model = cairo_short_string_to_felt(&entity.model).unwrap();
let result = storage.set_entity(model, entity.keys, values);
let result = storage.set_entity_storage(model, entity.keys, values);

assert!(storage.storage.read().is_empty());
matches!(
Expand All @@ -155,7 +227,7 @@ mod tests {

let values = vec![felt!("1"), felt!("2")];
let model = cairo_short_string_to_felt(&entity.model).unwrap();
let result = storage.set_entity(model, entity.keys, values);
let result = storage.set_entity_storage(model, entity.keys, values);

assert!(storage.storage.read().is_empty());
matches!(
Expand Down Expand Up @@ -186,11 +258,11 @@ mod tests {
let model_name_in_felt = cairo_short_string_to_felt(&entity.model).unwrap();

storage
.set_entity(model_name_in_felt, entity.keys.clone(), expected_values.clone())
.set_entity_storage(model_name_in_felt, entity.keys.clone(), expected_values.clone())
.expect("set storage values");

let actual_values = storage
.get_entity(model_name_in_felt, &entity.keys)
.get_entity_storage(model_name_in_felt, &entity.keys)
.expect("model exist")
.expect("values are set");

Expand Down
Loading

0 comments on commit 50194a9

Please sign in to comment.