Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Polkadot: Implement caller_is_root runtime API #1620

Merged
merged 11 commits into from
Feb 1, 2024
Merged
5 changes: 5 additions & 0 deletions docs/language/builtins.rst
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,11 @@ is_contract(address AccountId) returns (bool)

Only available on Polkadot. Checks whether the given address is a contract address.

caller_is_root() returns (bool)
+++++++++++++++++++++++++++++++

Only available on Polkadot. Returns true if the caller of the contract is `root <https://docs.substrate.io/build/origins/>`_.

set_code_hash(uint8[32] hash) returns (uint32)
++++++++++++++++++++++++++++++++++++++++++++++

Expand Down
13 changes: 13 additions & 0 deletions integration/polkadot/caller_is_root.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import "polkadot";

contract CallerIsRoot {
uint public balance;

function covert() public payable {
if (caller_is_root()) {
balance = 0xdeadbeef;
} else {
balance = 1;
}
}
}
42 changes: 42 additions & 0 deletions integration/polkadot/caller_is_root.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import expect from 'expect';
import { createConnection, deploy, aliceKeypair, query, weight, transaction } from './index';
import { ContractPromise } from '@polkadot/api-contract';
import { ApiPromise } from '@polkadot/api';
import { KeyringPair } from '@polkadot/keyring/types';

describe('Deploy the caller_is_root contract and test it', () => {
let conn: ApiPromise;
let contract: ContractPromise;
let alice: KeyringPair;

before(async function () {
conn = await createConnection();
alice = aliceKeypair();
const instance = await deploy(conn, alice, 'CallerIsRoot.contract', 0n);
contract = new ContractPromise(conn, instance.abi, instance.address);
});

after(async function () {
await conn.disconnect();
});

it('is correct on a non-root caller', async function () {
// Without sudo the caller should not be root
const gasLimit = await weight(conn, contract, "covert");
await transaction(contract.tx.covert({ gasLimit }), alice);

// Calling `covert` as non-root sets the balance to 1
const balance = await query(conn, alice, contract, "balance", []);
expect(BigInt(balance.output?.toString() ?? "")).toStrictEqual(1n);
});

it('is correct on a root caller', async function () {
// Alice has sudo rights on --dev nodes
const gasLimit = await weight(conn, contract, "covert");
await transaction(conn.tx.sudo.sudo(contract.tx.covert({ gasLimit })), alice);

// Calling `covert` as root sets the balance to 0xdeadbeef
const balance = await query(conn, alice, contract, "balance", []);
expect(BigInt(balance.output?.toString() ?? "")).toStrictEqual(0xdeadbeefn);
});
});
12 changes: 8 additions & 4 deletions src/emit/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,10 +519,14 @@ pub(super) fn process_instruction<'a, T: TargetRuntime<'a> + ?Sized>(
}
}

let first_arg_type = bin.llvm_type(&args[0].ty(), ns);
if let Some(ret) =
target.builtin_function(bin, function, callee, &parms, first_arg_type, ns)
{
if let Some(ret) = target.builtin_function(
bin,
function,
callee,
&parms,
args.first().map(|arg| bin.llvm_type(&arg.ty(), ns)),
ns,
) {
let success = bin.builder.build_int_compare(
IntPredicate::EQ,
ret.into_int_value(),
Expand Down
2 changes: 1 addition & 1 deletion src/emit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ pub trait TargetRuntime<'a> {
function: FunctionValue<'a>,
builtin_func: &Function,
args: &[BasicMetadataValueEnum<'a>],
first_arg_type: BasicTypeEnum,
first_arg_type: Option<BasicTypeEnum>,
ns: &Namespace,
) -> Option<BasicValueEnum<'a>>;

Expand Down
2 changes: 2 additions & 0 deletions src/emit/polkadot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ impl PolkadotTarget {
"transfer",
"is_contract",
"set_code_hash",
"caller_is_root",
]);

binary
Expand Down Expand Up @@ -266,6 +267,7 @@ impl PolkadotTarget {
external!("deposit_event", void_type, u8_ptr, u32_val, u8_ptr, u32_val);
external!("is_contract", i32_type, u8_ptr);
external!("set_code_hash", i32_type, u8_ptr);
external!("caller_is_root", i32_type,);
}

/// Emits the "deploy" function if `storage_initializer` is `Some`, otherwise emits the "call" function.
Expand Down
13 changes: 12 additions & 1 deletion src/emit/polkadot/target.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1501,7 +1501,7 @@ impl<'a> TargetRuntime<'a> for PolkadotTarget {
_function: FunctionValue<'a>,
builtin_func: &Function,
args: &[BasicMetadataValueEnum<'a>],
_first_arg_type: BasicTypeEnum,
_first_arg_type: Option<BasicTypeEnum>,
ns: &Namespace,
) -> Option<BasicValueEnum<'a>> {
emit_context!(binary);
Expand Down Expand Up @@ -1579,6 +1579,17 @@ impl<'a> TargetRuntime<'a> for PolkadotTarget {
.build_store(args[1].into_pointer_value(), ret);
None
}
"caller_is_root" => {
let is_root = call!("caller_is_root", &[], "seal_caller_is_root")
.try_as_basic_value()
.left()
.unwrap()
.into_int_value();
binary
.builder
.build_store(args[0].into_pointer_value(), is_root);
None
}
_ => unimplemented!(),
}
}
Expand Down
5 changes: 4 additions & 1 deletion src/emit/solana/target.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1251,9 +1251,12 @@ impl<'a> TargetRuntime<'a> for SolanaTarget {
function: FunctionValue<'a>,
builtin_func: &ast::Function,
args: &[BasicMetadataValueEnum<'a>],
first_arg_type: BasicTypeEnum,
first_arg_type: Option<BasicTypeEnum>,
ns: &ast::Namespace,
) -> Option<BasicValueEnum<'a>> {
let first_arg_type =
first_arg_type.expect("solana does not have builtin without any parameter");

if builtin_func.id.name == "create_program_address" {
let func = binary
.module
Expand Down
2 changes: 1 addition & 1 deletion src/emit/soroban/target.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ impl<'a> TargetRuntime<'a> for SorobanTarget {
function: FunctionValue<'a>,
builtin_func: &Function,
args: &[BasicMetadataValueEnum<'a>],
first_arg_type: BasicTypeEnum,
first_arg_type: Option<BasicTypeEnum>,
ns: &Namespace,
) -> Option<BasicValueEnum<'a>> {
unimplemented!()
Expand Down
27 changes: 27 additions & 0 deletions src/sema/builtin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1826,6 +1826,33 @@ impl Namespace {
}],
self,
),
// caller_is_root API
Function::new(
loc,
loc,
pt::Identifier {
name: "caller_is_root".to_string(),
loc,
},
None,
Vec::new(),
pt::FunctionTy::Function,
Some(pt::Mutability::View(loc)),
pt::Visibility::Public(Some(loc)),
vec![],
vec![Parameter {
loc,
id: Some(identifier("caller_is_root")),
ty: Type::Bool,
ty_loc: Some(loc),
readonly: false,
indexed: false,
infinite_size: false,
recursive: false,
annotation: None,
}],
self,
),
] {
func.has_body = true;
let func_no = self.functions.len();
Expand Down
14 changes: 7 additions & 7 deletions tests/lir_tests/convert_lir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ fn test_assertion_using_require() {
assert_polkadot_lir_str_eq(
src,
0,
r#"public function sol#3 Test::Test::function::test__int32 (int32):
r#"public function sol#4 Test::Test::function::test__int32 (int32):
block#0 entry:
int32 %num = int32(arg#0);
bool %temp.ssa_ir.1 = int32(%num) > int32(10);
Expand Down Expand Up @@ -690,7 +690,7 @@ fn test_call_1() {
assert_polkadot_lir_str_eq(
src,
0,
r#"public function sol#3 Test::Test::function::test__int32 (int32):
r#"public function sol#4 Test::Test::function::test__int32 (int32):
block#0 entry:
int32 %num = int32(arg#0);
= call function#1(int32(%num));
Expand Down Expand Up @@ -754,7 +754,7 @@ fn test_value_transfer() {
assert_polkadot_lir_str_eq(
src,
0,
r#"public function sol#3 Test::Test::function::transfer__address_uint128 (uint8[32], uint128):
r#"public function sol#4 Test::Test::function::transfer__address_uint128 (uint8[32], uint128):
block#0 entry:
uint8[32] %addr = uint8[32](arg#0);
uint128 %amount = uint128(arg#1);
Expand Down Expand Up @@ -928,7 +928,7 @@ fn test_keccak256() {
assert_polkadot_lir_str_eq(
src,
0,
r#"public function sol#3 b::b::function::add__string_address (ptr<struct.vector<uint8>>, uint8[32]):
r#"public function sol#4 b::b::function::add__string_address (ptr<struct.vector<uint8>>, uint8[32]):
block#0 entry:
ptr<struct.vector<uint8>> %name = ptr<struct.vector<uint8>>(arg#0);
uint8[32] %addr = uint8[32](arg#1);
Expand Down Expand Up @@ -960,7 +960,7 @@ fn test_internal_function_cfg() {
assert_polkadot_lir_str_eq(
src,
1,
r#"public function sol#4 A::A::function::bar__uint256 (uint256) returns (uint256):
r#"public function sol#5 A::A::function::bar__uint256 (uint256) returns (uint256):
block#0 entry:
uint256 %b = uint256(arg#0);
ptr<function (uint256) returns (uint256)> %temp.ssa_ir.6 = function#0;
Expand Down Expand Up @@ -1124,14 +1124,14 @@ fn test_constructor() {
assert_polkadot_lir_str_eq(
src,
0,
r#"public function sol#3 B::B::function::test__uint256 (uint256):
r#"public function sol#4 B::B::function::test__uint256 (uint256):
block#0 entry:
uint256 %a = uint256(arg#0);
ptr<struct.vector<uint8>> %abi_encoded.temp.18 = alloc ptr<struct.vector<uint8>>[uint32(36)];
uint32 %temp.ssa_ir.20 = uint32 hex"58_16_c4_25";
write_buf ptr<struct.vector<uint8>>(%abi_encoded.temp.18) offset:uint32(0) value:uint32(%temp.ssa_ir.20);
write_buf ptr<struct.vector<uint8>>(%abi_encoded.temp.18) offset:uint32(4) value:uint256(%a);
uint32 %success.temp.17, uint8[32] %temp.16 = constructor(no: 5, contract_no:1) salt:_ value:_ gas:uint64(0) address:_ seeds:_ encoded-buffer:ptr<struct.vector<uint8>>(%abi_encoded.temp.18) accounts:absent
uint32 %success.temp.17, uint8[32] %temp.16 = constructor(no: 6, contract_no:1) salt:_ value:_ gas:uint64(0) address:_ seeds:_ encoded-buffer:ptr<struct.vector<uint8>>(%abi_encoded.temp.18) accounts:absent
switch uint32(%success.temp.17):
case: uint32(0) => block#1,
case: uint32(2) => block#2
Expand Down
12 changes: 12 additions & 0 deletions tests/polkadot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,8 @@ fn read_hash(mem: &[u8], ptr: u32) -> Hash {
/// Host functions mock the original implementation, refer to the [pallet docs][1] for more information.
///
/// [1]: https://docs.rs/pallet-contracts/latest/pallet_contracts/api_doc/index.html
///
/// Address `[0; u8]` is considered the root account.
#[wasm_host]
impl Runtime {
#[seal(0)]
Expand Down Expand Up @@ -787,6 +789,11 @@ impl Runtime {
.into())
}

#[seal(0)]
fn caller_is_root() -> Result<u32, Trap> {
Ok((vm.accounts[vm.caller_account].address == [0; 32]).into())
}

#[seal(0)]
fn set_code_hash(code_hash_ptr: u32) -> Result<u32, Trap> {
let hash = read_hash(mem, code_hash_ptr);
Expand Down Expand Up @@ -818,6 +825,11 @@ impl MockSubstrate {
Ok(())
}

/// Overwrites the address at asssociated `account` index with the given `address`.
pub fn set_account_address(&mut self, account: usize, address: [u8; 32]) {
self.0.data_mut().accounts[account].address = address;
}

/// Specify the caller account index for the next function or constructor call.
pub fn set_account(&mut self, index: usize) {
self.0.data_mut().account = index;
Expand Down
21 changes: 21 additions & 0 deletions tests/polkadot_tests/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -845,3 +845,24 @@ fn set_code_hash() {
runtime.function("count", vec![]);
assert_eq!(runtime.output(), 1u32.encode());
}

#[test]
fn caller_is_root() {
let mut runtime = build_solidity(
r#"
import { caller_is_root } from "polkadot";
contract Test {
function test() public view returns (bool) {
return caller_is_root();
}
}"#,
);

runtime.function("test", runtime.0.data().accounts[0].address.to_vec());
assert_eq!(runtime.output(), false.encode());

// Set the caller address to [0; 32] which is the mock VM root account
runtime.set_account_address(0, [0; 32]);
runtime.function("test", [0; 32].to_vec());
assert_eq!(runtime.output(), true.encode());
}
Loading