Skip to content

Commit

Permalink
update slot and controller
Browse files Browse the repository at this point in the history
  • Loading branch information
kariy committed Sep 13, 2024
1 parent d64fdae commit a516b5e
Show file tree
Hide file tree
Showing 24 changed files with 525 additions and 353 deletions.
484 changes: 280 additions & 204 deletions Cargo.lock

Large diffs are not rendered by default.

36 changes: 16 additions & 20 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ inherits = "release"
lto = "fat"

[workspace.dependencies]
cainome = { git = "/~https://github.com/cartridge-gg/cainome", tag = "v0.3.2", features = [ "abigen-rs" ] }
cainome = { git = "/~https://github.com/cartridge-gg/cainome", rev = "0d29bb0", features = [ "abigen-rs" ] }
dojo-utils = { path = "crates/dojo-utils" }

# metrics
Expand Down Expand Up @@ -117,7 +117,7 @@ sozo-walnut = { path = "crates/sozo/walnut" }
anyhow = "1.0.80"
assert_fs = "1.1"
assert_matches = "1.5.0"
async-trait = "0.1.77"
async-trait = "0.1.82"
auto_impl = "1.2.0"
base64 = "0.21.2"
bigdecimal = "0.4.1"
Expand Down Expand Up @@ -158,6 +158,7 @@ derive_more = "0.99.17"
flate2 = "1.0.24"
futures = "0.3.30"
futures-util = "0.3.30"
hashlink = "0.9.1"
hex = "0.4.3"
http = "0.2.9"
indexmap = "2.2.5"
Expand All @@ -184,18 +185,10 @@ scarb-ui = { git = "/~https://github.com/software-mansion/scarb", tag = "v2.7.0" }
semver = "1.0.5"
serde = { version = "1.0", features = [ "derive" ] }
serde_json = { version = "1.0", features = [ "arbitrary_precision" ] }
serde_with = "2.3"
serde_with = "3.9.0"
similar-asserts = "1.5.0"
smol_str = { version = "0.2.0", features = [ "serde" ] }
sqlx = { version = "0.7.2", features = [ "chrono", "macros", "regexp", "runtime-async-std", "runtime-tokio", "sqlite", "uuid" ] }
starknet = "0.11.0"
starknet-crypto = "0.7.0"
# `starknet-rs` is using `starknet-types-core` 0.1.3, but we need >=0.1.4 because
# we need this </~https://github.com/starknet-io/types-rs/pull/75>. So we put strict
# requirement here to prevent from being downgraded.
# We can remove this requirement once `starknet-rs` is using >=0.1.4
hashlink = "0.9.1"
starknet-types-core = "~0.1.4"
starknet_api = "0.11.0"
strum = "0.25"
strum_macros = "0.25"
Expand Down Expand Up @@ -233,9 +226,8 @@ alloy-sol-types = { version = "0.7.6", default-features = false }

criterion = "0.5.1"

# Controller integration
account_sdk = { git = "/~https://github.com/cartridge-gg/controller", rev = "512ff89" }
slot = { git = "/~https://github.com/cartridge-gg/slot", rev = "4c1165d" }
# Slot integration. Dojo don't need to manually include `account_sdk` as dependency as `slot` already re-exports it.
slot = { git = "/~https://github.com/cartridge-gg/slot", tag = "v0.14.0" }

alloy-contract = { version = "0.2", default-features = false }
alloy-json-rpc = { version = "0.2", default-features = false }
Expand All @@ -245,10 +237,14 @@ alloy-rpc-types-eth = { version = "0.2", default-features = false }
alloy-signer = { version = "0.2", default-features = false }
alloy-transport = { version = "0.2", default-features = false }

starknet = "0.11.0"
starknet-crypto = "0.7.1"
# `starknet-rs` is using `starknet-types-core` 0.1.3, but we need >=0.1.4 because
# we need this </~https://github.com/starknet-io/types-rs/pull/75>. So we put strict
# requirement here to prevent from being downgraded.
# We can remove this requirement once `starknet-rs` is using >=0.1.4
starknet-types-core = "~0.1.4"

[patch.crates-io]
# Remove this patch once the following PR is merged: </~https://github.com/xJonathanLEI/starknet-rs/pull/615>
#
# To enable std feature on `starknet-types-core`.
# To re-export the entire `felt` module from `starknet-types-core`.
starknet-core = { git = "/~https://github.com/kariy/starknet-rs", branch = "dojo-patch" }
starknet-types-core = { git = "/~https://github.com/dojoengine/types-rs", rev = "289e2f0" }
starknet = { git = "/~https://github.com/xJonathanLEI/starknet-rs", rev = "2ddc694" }
starknet-types-core = { git = "/~https://github.com/starknet-io/types-rs", rev = "f98f048" }
3 changes: 1 addition & 2 deletions bin/sozo/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ version.workspace = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
account_sdk = { workspace = true, optional = true }
slot = { workspace = true, optional = true }

anyhow.workspace = true
Expand Down Expand Up @@ -73,5 +72,5 @@ snapbox = "0.4.6"
[features]
default = [ "controller", "walnut" ]

controller = [ "dep:account_sdk", "dep:reqwest", "dep:slot" ]
controller = [ "dep:reqwest", "dep:slot" ]
walnut = [ "dep:sozo-walnut", "sozo-ops/walnut" ]
114 changes: 61 additions & 53 deletions bin/sozo/src/commands/options/account/controller.rs
Original file line number Diff line number Diff line change
@@ -1,33 +1,40 @@
use std::str::FromStr;
use std::sync::Arc;

use account_sdk::account::session::hash::{AllowedMethod, Session};
use account_sdk::account::session::SessionAccount;
use account_sdk::signers::HashSigner;
use anyhow::{bail, Context, Result};
use camino::{Utf8Path, Utf8PathBuf};
use dojo_utils::TransactionWaiter;
use dojo_world::contracts::naming::get_name_from_tag;
use dojo_world::manifest::{BaseManifest, Class, DojoContract, Manifest};
use dojo_world::migration::strategy::generate_salt;
use scarb::core::Config;
use slot::session::Policy;
use slot::account_sdk::account::session::hash::{AllowedMethod, ProvedMethod};
use slot::account_sdk::account::session::merkle::MerkleTree;
use slot::account_sdk::account::session::SessionAccount;
use slot::session::{FullSessionInfo, PolicyMethod};
use starknet::core::types::contract::{AbiEntry, StateMutability};
use starknet::core::types::StarknetError::ContractNotFound;
use starknet::core::types::{BlockId, BlockTag, Felt};
use starknet::core::utils::{cairo_short_string_to_felt, get_contract_address};
use starknet::macros::{felt, short_string};
use starknet::core::utils::{
cairo_short_string_to_felt, get_contract_address, get_selector_from_name,
};
use starknet::macros::felt;
use starknet::providers::Provider;
use starknet::providers::ProviderError::StarknetError;
use starknet::signers::SigningKey;
use starknet_crypto::poseidon_hash_single;
use tracing::{trace, warn};
use url::Url;

use super::WorldAddressOrName;

// Why the Arc? becaues the Controller account implementation over on `account_sdk` crate is
// riddled with `+ Clone` bounds on its Provider generic. So we explicitly specify that the Provider
// impl here is wrapped in an Arc to satisfy the Clone bound. Otherwise, you would get a 'trait
// bound not satisfied' error.
//
// This type comes from account_sdk, which doesn't derive Debug.
#[allow(missing_debug_implementations)]
pub type ControllerSessionAccount<P> = SessionAccount<P, SigningKey, SigningKey>;
pub type ControllerSessionAccount<P> = SessionAccount<Arc<P>>;

/// Create a new Catridge Controller account based on session key.
#[tracing::instrument(
Expand Down Expand Up @@ -59,31 +66,32 @@ where
"Creating Controller session account"
);

// make sure account exist on the provided chain, if not, we deploy it first before proceeding
deploy_account_if_not_exist(rpc_url.clone(), &provider, chain_id, contract_address, &username)
.await
.with_context(|| format!("Deploying Controller account on chain {chain_id}"))?;

// Check if the session exists, if not create a new one
let session_details = match slot::session::get(chain_id)? {
Some(session) => {
trace!(expires_at = %session.expires_at, policies = session.policies.len(), "Found existing session.");
trace!(expires_at = %session.session.expires_at, policies = session.session.allowed_methods.len(), "Found existing session.");

// Perform policies diff check. For security reasons, we will always create a new
// session here if the current policies are different from the existing
// session.
//
// TODO(kariy): maybe don't need to update if current policies is a
// subset of the existing policies.
let policies = collect_policies(world_addr_or_name, contract_address, config)?;
// check if the policies have changed
let is_equal = is_equal_to_existing(&policies, &session);

if policies != session.policies {
if is_equal {
session
} else {
trace!(
new_policies = policies.len(),
existing_policies = session.policies.len(),
existing_policies = session.session.allowed_methods.len(),
"Policies have changed. Creating new session."
);

let session = slot::session::create(rpc_url.clone(), &policies).await?;
slot::session::store(chain_id, &session)?;
session
} else {
session
}
}

Expand All @@ -97,35 +105,34 @@ where
}
};

let methods = session_details
.policies
.into_iter()
.map(|p| AllowedMethod::new(p.target, &p.method))
.collect::<Result<Vec<AllowedMethod>, _>>()?;
Ok(session_details.into_account(Arc::new(provider)))
}

// Copied from `account-wasm` </~https://github.com/cartridge-gg/controller/blob/0dd4dd6cbc5fcd3b9a1fd8d63dc127f6312b733f/packages/account-wasm/src/lib.rs#L78-L88>
let guardian = SigningKey::from_secret_scalar(short_string!("CARTRIDGE_GUARDIAN"));
let signer = SigningKey::from_secret_scalar(session_details.credentials.private_key);
// TODO(kariy): make `expires_at` a `u64` type in the session struct
let expires_at = session_details.expires_at.parse::<u64>()?;
let session = Session::new(methods, expires_at, &signer.signer())?;
// Check if the new policies are equal to the ones in the existing session
//
// This function would compute the merkle root of the new policies and compare it with the root in
// the existing SessionMetadata.
fn is_equal_to_existing(new_policies: &[PolicyMethod], session_info: &FullSessionInfo) -> bool {
let allowed_methods = new_policies
.iter()
.map(|p| AllowedMethod::new(p.target, get_selector_from_name(&p.method).unwrap()))
.collect::<Vec<AllowedMethod>>();

// make sure account exist on the provided chain, if not, we deploy it first before proceeding
deploy_account_if_not_exist(rpc_url, &provider, chain_id, contract_address, &username)
.await
.with_context(|| format!("Deploying Controller account on chain {chain_id}"))?;
// Copied from somewhere
let hashes = allowed_methods.iter().map(AllowedMethod::as_merkle_leaf).collect::<Vec<Felt>>();

let session_account = SessionAccount::new(
provider,
signer,
guardian,
contract_address,
chain_id,
session_details.credentials.authorization,
session,
);
let allowed_methods = allowed_methods
.into_iter()
.enumerate()
.map(|(i, method)| ProvedMethod {
method,
proof: MerkleTree::compute_proof(hashes.clone(), i),
})
.collect::<Vec<ProvedMethod>>();

let root = MerkleTree::compute_root(hashes[0], allowed_methods[0].proof.clone());

Ok(session_account)
root == session_info.session.allowed_methods_root
}

/// Policies are the building block of a session key. It's what defines what methods are allowed for
Expand All @@ -137,7 +144,7 @@ fn collect_policies(
world_addr_or_name: WorldAddressOrName,
user_address: Felt,
config: &Config,
) -> Result<Vec<Policy>> {
) -> Result<Vec<PolicyMethod>> {
let root_dir = config.root();
let manifest = get_project_base_manifest(root_dir, config.profile().as_str())?;
let policies =
Expand All @@ -157,8 +164,8 @@ fn collect_policies_from_base_manifest(
user_address: Felt,
base_path: &Utf8Path,
manifest: BaseManifest,
) -> Result<Vec<Policy>> {
let mut policies: Vec<Policy> = Vec::new();
) -> Result<Vec<PolicyMethod>> {
let mut policies: Vec<PolicyMethod> = Vec::new();
let base_path: Utf8PathBuf = base_path.to_path_buf();

// compute the world address here if it's a name
Expand All @@ -180,14 +187,14 @@ fn collect_policies_from_base_manifest(
// special policy for sending declare tx
// corresponds to [account_sdk::account::DECLARATION_SELECTOR]
let method = "__declare_transaction__".to_string();
policies.push(Policy { target: user_address, method });
policies.push(PolicyMethod { target: user_address, method });
trace!("Adding declare transaction policy");

// for deploying using udc
let method = "deployContract".to_string();
const UDC_ADDRESS: Felt =
felt!("0x041a78e741e5af2fec34b695679bc6891742439f7afb8484ecd7766661ad02bf");
policies.push(Policy { target: UDC_ADDRESS, method });
policies.push(PolicyMethod { target: UDC_ADDRESS, method });
trace!("Adding UDC deployment policy");

Ok(policies)
Expand All @@ -196,7 +203,7 @@ fn collect_policies_from_base_manifest(
/// Recursively extract methods and convert them into policies from the all the
/// ABIs in the project.
fn policies_from_abis(
policies: &mut Vec<Policy>,
policies: &mut Vec<PolicyMethod>,
contract_tag: &str,
contract_address: Felt,
entries: &[AbiEntry],
Expand All @@ -206,7 +213,8 @@ fn policies_from_abis(
AbiEntry::Function(f) => {
// we only create policies for non-view functions
if let StateMutability::External = f.state_mutability {
let policy = Policy { target: contract_address, method: f.name.to_string() };
let policy =
PolicyMethod { target: contract_address, method: f.name.to_string() };
trace!(tag = contract_tag, target = format!("{:#x}", policy.target), method = %policy.method, "Adding policy");
policies.push(policy);
}
Expand Down Expand Up @@ -357,7 +365,7 @@ mod tests {
use scarb::compiler::Profile;
use starknet::macros::felt;

use super::{collect_policies, Policy};
use super::{collect_policies, PolicyMethod};
use crate::commands::options::account::WorldAddressOrName;

#[test]
Expand All @@ -373,7 +381,7 @@ mod tests {

// Get test data
let test_data = include_str!("../../../../tests/test_data/policies.json");
let expected_policies: Vec<Policy> = serde_json::from_str(test_data).unwrap();
let expected_policies: Vec<PolicyMethod> = serde_json::from_str(test_data).unwrap();

// Compare the collected policies with the test data
assert_eq!(policies.len(), expected_policies.len());
Expand Down
7 changes: 4 additions & 3 deletions bin/sozo/src/commands/options/account/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,11 @@ impl AccountOptions {
if self.controller {
let url = starknet.url(env_metadata)?;
let account = self.controller(url, provider, world_address_or_name, config).await?;
return Ok(SozoAccount::from(account));
return Ok(SozoAccount::Controller(account));
}

let account = self.std_account(provider, env_metadata).await?;
Ok(SozoAccount::from(account))
Ok(SozoAccount::Standard(account))
}

pub async fn std_account<P>(
Expand Down Expand Up @@ -151,7 +151,8 @@ impl AccountOptions {
#[cfg(test)]
mod tests {
use clap::Parser;
use starknet::accounts::{Call, ExecutionEncoder};
use starknet::accounts::ExecutionEncoder;
use starknet::core::types::Call;
use starknet_crypto::Felt;

use super::{AccountOptions, DOJO_ACCOUNT_ADDRESS_ENV_VAR};
Expand Down
Loading

0 comments on commit a516b5e

Please sign in to comment.