Skip to content

Commit

Permalink
fix ast2builder function call type bug
Browse files Browse the repository at this point in the history
  • Loading branch information
t81lal committed Jun 13, 2024
1 parent 3261ec7 commit 170633f
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 18 deletions.
42 changes: 24 additions & 18 deletions src/solidity_parser/ast/ast2builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,19 @@ def get_expr_type(self, expr: solnodes1.Expr | soltypes.Type, allow_multiple=Fal
# lookup a single unqualified Ident in the current scope(expr.scope). Note this path ISN'T taken for
# qualified lookups (e.g. x.y)
if allow_multiple:
# return all matching symbol types, TODO: should use ACCEPT_INHERITABLE here also?
return [self.symbol_to_ast2_type(s, function_callee=function_callee) for s in expr.scope.find(text)]
# return all matching symbol types
symbols = expr.scope.find(text)
any_or_all_funcs = self.any_or_all([isinstance(s.value, symtab.ModFunErrEvtScope) for s in symbols])
self.error_handler.assert_error('Expected any or all function callees', any_or_all_funcs)

if function_callee and any_or_all_funcs:
# less sophisticated way compared to refine_expr but we need to get the function at the top of
# the hierarchy (symtab returns them all)
function_chain = self.builder.is_declaration_chain(symbols)
self.error_handler.assert_error('Multiple matching functions must be in a hierarchy chain', function_chain)
return [self.symbol_to_ast2_type(symbols[0], function_callee=function_callee)]

return [self.symbol_to_ast2_type(s, function_callee=function_callee) for s in symbols]
else:
inheritable_predicate = symtab.ACCEPT_INHERITABLE(expr.scope)
symbols = expr.scope.find(expr.text, predicate=inheritable_predicate)
Expand Down Expand Up @@ -1317,12 +1328,9 @@ def is_type_call(s):
# i.e. we might have
# B is A { f() } A { f() }
# So we resolve to B.f
symbol_sources = [self.get_declaring_contract_scope_in_scope(candidate[0]) for candidate in bucket_candidates]
# if symbol_sources[0] == source, we matched multiple symbols in the same bucket that aren't in an
# override chain, e.g. multiple matches in the same contract
are_sub_contracts = all([self.is_subcontract(symbol_sources[0], source) and not symbol_sources[0] == source for source in symbol_sources[1:]])
are_sub_contracts = self.is_declaration_chain([candidate[0] for candidate in bucket_candidates])
if are_sub_contracts:
aliases = ', '.join([s.aliases[0] for s in symbol_sources])
aliases = ', '.join([self.get_declaring_contract_scope_in_scope(c[0]).aliases[0] for c in bucket_candidates])
logging.getLogger('AST2').debug(f'Base chain: {aliases} @ {expr.location}')
candidates.append((c.base, *bucket_candidates[0])) # type: ignore
else:
Expand Down Expand Up @@ -1585,6 +1593,11 @@ def get_function_callee_buckets(self, symbols: List[symtab.Symbol]):

return [Builder.FunctionCallee(base, symbols) for base, symbols in new_buckets.items()]

def is_declaration_chain(self, function_symbols):
symbol_sources = [self.get_declaring_contract_scope_in_scope(s) for s in function_symbols]
are_sub_contracts = all([self.is_subcontract(symbol_sources[0], source) for source in symbol_sources[1:]])
return are_sub_contracts

@link_with_ast1
def refine_expr(self, expr: solnodes1.Expr, is_function_callee=False, allow_type=False, allow_tuple_exprs=False,
allow_multiple_exprs=False, allow_none=True, allow_stmt=False, is_argument=False,
Expand Down Expand Up @@ -1834,11 +1847,8 @@ def z():
# TODO: can this be ambiguous or does the reference always select a single function
if len(member_symbols) == 1:
# Func pointer load e.g. this is the first param in abi.encodeCall(A.f, ...)
symbol_sources = [self.get_declaring_contract_scope_in_scope(d) for d in member_symbols[0]]
are_sub_contracts = all(
[self.is_subcontract(symbol_sources[0], source) for source in symbol_sources[1:]])
self.error_handler.assert_error(f'{expr} has too many target definitions ({len(member_symbols[0])})',
are_sub_contracts)
are_sub_contracts = self.is_declaration_chain(member_symbols[0])
self.error_handler.assert_error(f'{expr} has too many target definitions ({len(member_symbols[0])})',are_sub_contracts)
func_sym = member_symbols[0][0]
if isinstance(func_sym.value, solnodes1.FunctionDefinition):
return solnodes2.GetFunctionPointer(nodebase.Ref(func_sym.value.ast2_node))
Expand Down Expand Up @@ -1911,9 +1921,7 @@ def z():

if directly_referenced_callables[0]:
# check that functions are part of a chain
symbol_sources = [self.get_declaring_contract_scope_in_scope(s) for s in callee.symbols]
are_sub_contracts = all(
[self.is_subcontract(symbol_sources[0], source) for source in symbol_sources[1:]])
are_sub_contracts = self.is_declaration_chain(callee.symbols)
self.error_handler.assert_error(f'Too many target definitions ({len(callee.symbols)})', are_sub_contracts)

member_symbol = callee.symbols[0]
Expand Down Expand Up @@ -2082,9 +2090,7 @@ def modifier(self, node: solnodes1.Modifier):
mod_defs = node.scope.find(node.name.text)
if len(mod_defs) > 1:
# If we have multiple matches, check that they are part of an override chain and pick the first one
symbol_sources = [self.get_declaring_contract_scope_in_scope(d) for d in mod_defs]
are_sub_contracts = all(
[self.is_subcontract(symbol_sources[0], source) for source in symbol_sources[1:]])
are_sub_contracts = self.is_declaration_chain(mod_defs)
self.error_handler.assert_error(f'{node.name.text} has too many target definitions ({len(mod_defs)})', are_sub_contracts)

target = self.load_non_top_level_if_required(mod_defs[0].value)
Expand Down
8 changes: 8 additions & 0 deletions test/solidity_parser/ast/test_ast2builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ def test_get_expr_type_member_not_found(self):
self.assertEqual(self.type_helper.get_expr_type(exprs[0]), [])
self.assertEqual(self.type_helper.get_expr_type(exprs[0], allow_multiple=True), [])

def test_get_expr_type_function_call_hierarchy(self):
file_scope = self.symtab_builder.process_or_find_from_base_dir('HierarchyFunctions.sol')
ast1_nodes = file_scope.value

expr = [c for u in ast1_nodes if u for c in u.get_all_children() if isinstance(c, solnodes1.CallFunction) and str(c.callee) == 'hasFoo'][0]

self.assertEqual(self.type_helper.get_function_expr_type(expr), soltypes.BoolType())

def test_get_expr_type_float(self):
file_scope = self.symtab_builder.process_or_find_from_base_dir('float_type.sol')
ast1_nodes = file_scope.value
Expand Down
21 changes: 21 additions & 0 deletions testcases/type_tests/HierarchyFunctions.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
pragma solidity 0.8.20;

interface IBar {
function hasFoo(bytes32 f, address account) external view returns (bool);
}

abstract contract Bar is IBar {
function hasFoo(bytes32 f, address account) public view virtual returns (bool) {
return false;
}
}

contract HierarchyTest is Bar {

bytes32 public constant FOO = keccak256("FOO");

modifier testMyFoo() {
require(hasFoo(FOO, msg.sender), "Foofail");
_;
}
}

0 comments on commit 170633f

Please sign in to comment.