Skip to content

Commit

Permalink
feat: support _replace method on ARC4Struct
Browse files Browse the repository at this point in the history
  • Loading branch information
achidlow committed Jan 17, 2025
1 parent 9e78626 commit 3b1268d
Show file tree
Hide file tree
Showing 42 changed files with 1,203 additions and 615 deletions.
4 changes: 2 additions & 2 deletions examples/sizes.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
arc4_types/Arc4RefTypes 85 46 - | 32 27 -
arc4_types/Arc4StringTypes 349 35 - | 149 13 -
arc4_types/Arc4StructsFromAnotherModule 67 12 - | 49 6 -
arc4_types/Arc4StructsType 386 48 - | 258 16 -
arc4_types/Arc4StructsType 424 48 - | 284 16 -
arc4_types/Arc4TuplesType 938 8 - | 644 4 -
arc4_types/MutableParams2 318 193 48 | 185 92 23
arc_28/EventEmitter 172 124 102 | 92 58 48
Expand Down Expand Up @@ -138,4 +138,4 @@
unssa/UnSSA 420 266 - | 237 153 -
voting/VotingRoundApp 1584 1426 1415 | 725 624 625
with_reentrancy/WithReentrancy 245 214 - | 126 108 -
Total 70482 39117 36225 | 33270 18249 16976
Total 70520 39117 36225 | 33296 18249 16976
54 changes: 52 additions & 2 deletions src/puyapy/awst_build/eb/arc4/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

from puya import log
from puya.awst import wtypes
from puya.awst.nodes import Expression, FieldExpression, NewStruct
from puya.awst.nodes import Copy, Expression, FieldExpression, NewStruct
from puya.parse import SourceLocation
from puyapy.awst_build import pytypes
from puyapy.awst_build.eb import _expect as expect
from puyapy.awst_build.eb._base import NotIterableInstanceExpressionBuilder
from puyapy.awst_build.eb._base import FunctionBuilder, NotIterableInstanceExpressionBuilder
from puyapy.awst_build.eb._bytes_backed import (
BytesBackedInstanceExpressionBuilder,
BytesBackedTypeBuilder,
Expand Down Expand Up @@ -78,6 +78,8 @@ def member_access(self, name: str, location: SourceLocation) -> NodeBuilder:
return builder_for_instance(field, result_expr)
case "copy":
return CopyBuilder(self.resolve(), location, self.pytype)
case "_replace":
return _Replace(self, self.pytype, location)
case _:
return super().member_access(name, location)

Expand All @@ -88,3 +90,51 @@ def compare(

def bool_eval(self, location: SourceLocation, *, negate: bool = False) -> InstanceBuilder:
return constant_bool_and_error(value=True, location=location, negate=negate)


class _Replace(FunctionBuilder):
def __init__(
self,
instance: ARC4StructExpressionBuilder,
struct_type: pytypes.StructType,
location: SourceLocation,
):
super().__init__(location)
self.instance = instance
self.struct_type = struct_type

@typing.override
def call(
self,
args: Sequence[NodeBuilder],
arg_kinds: list[mypy.nodes.ArgKind],
arg_names: list[str | None],
location: SourceLocation,
) -> InstanceBuilder:
pytype = self.struct_type
field_mapping, _ = get_arg_mapping(
optional_kw_only=list(pytype.fields),
args=args,
arg_names=arg_names,
call_location=location,
raise_on_missing=False,
)
base_expr = self.instance.single_eval().resolve()
values = dict[str, Expression]()
for field_name, field_pytype in pytype.fields.items():
new_value = field_mapping.get(field_name)
if new_value is not None:
item_builder = expect.argument_of_type_else_dummy(new_value, field_pytype)
item = item_builder.resolve()
else:
field_wtype = field_pytype.checked_wtype(location)
item = FieldExpression(base=base_expr, name=field_name, source_location=location)
if not field_wtype.immutable:
logger.error(
f"mutable field {field_name!r} requires explicit copy", location=location
)
# implicitly create a copy node so that there is only one error
item = Copy(value=item, source_location=location)
values[field_name] = item
new_tuple = NewStruct(values=values, wtype=pytype.wtype, source_location=location)
return ARC4StructExpressionBuilder(new_tuple, pytype)
2 changes: 1 addition & 1 deletion src/puyapy/awst_build/pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ class StructType(RuntimeType):
converter=immutabledict, validator=[attrs.validators.min_len(1)]
)
frozen: bool
wtype: wtypes.WType
wtype: wtypes.ARC4Struct | wtypes.WStructType
source_location: SourceLocation | None
generic: None = None
desc: str | None = None
Expand Down
5 changes: 5 additions & 0 deletions stubs/algopy-stubs/arc4.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,11 @@ class Struct(metaclass=_StructMeta):
def copy(self) -> typing.Self:
"""Create a copy of this struct"""

def _replace(self, **kwargs: typing.Any) -> typing.Self: # type: ignore[misc]
"""Return a new instance of the struct replacing specified fields with new values.
Note that any mutable fields must be explicitly copied to avoid aliasing."""

class ARC4Client(typing.Protocol):
"""Used to provide typed method signatures for ARC4 contracts"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"sources": [
"../structs.py"
],
"mappings": ";AAsC+B;;;;;;;;;;;;AAAX;;;;;;;;;;AACR;AADZ;AAAA;;;;;;;;;;;;AAGgB;;;AAER;AAcO;;AAAP",
"mappings": ";AAsC+B;;;;;;;;;;;;AAAX;;;;;;;;;;AACR;AADZ;AAAA;;;;;;;;;;;;AAGgB;;;AAER;AAkBO;;AAAP",
"op_pc_offset": 0,
"pc_events": {
"1": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ main_after_for@4:
// arc4_types/structs.py:44
// log(flags.bytes)
log
// arc4_types/structs.py:58
// arc4_types/structs.py:62
// return True
pushint 1 // 1
return
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"sources": [
"../structs.py"
],
"mappings": ";AA4De;;AAAP",
"mappings": ";AAgEe;;AAAP",
"op_pc_offset": 0,
"pc_events": {
"1": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

// test_cases.arc4_types.structs.Arc4StructsTypeContract.clear_state_program() -> uint64:
main:
// arc4_types/structs.py:61
// arc4_types/structs.py:65
// return True
pushint 1 // 1
return
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,24 @@ main test_cases.arc4_types.structs.Arc4StructsTypeContract.approval_program:
let no_copy#0: bytes = immutable#0
let tmp%5#0: bool = (== no_copy#0 immutable#0)
(assert tmp%5#0)
let tmp%6#0: bytes = (extract3 immutable#0 0u 8u) // on error: Index access is out of bounds
let current_tail_offset%7#0: uint64 = 16u
let encoded_tuple_buffer%20#0: bytes = 0x
let encoded_tuple_buffer%21#0: bytes = (concat encoded_tuple_buffer%20#0 tmp%6#0)
let encoded_tuple_buffer%22#0: bytes = (concat encoded_tuple_buffer%21#0 0x000000000000007b)
let immutable2#0: bytes = encoded_tuple_buffer%22#0
let reinterpret_biguint%0#0: biguint = (extract3 immutable2#0 8u 8u) // on error: Index access is out of bounds
let reinterpret_biguint%1#0: biguint = 0x000000000000007b
let tmp%7#0: bool = (b== reinterpret_biguint%0#0 reinterpret_biguint%1#0)
(assert tmp%7#0)
let reinterpret_biguint%2#0: biguint = (extract3 immutable2#0 0u 8u) // on error: Index access is out of bounds
let reinterpret_biguint%3#0: biguint = (extract3 immutable#0 0u 8u) // on error: Index access is out of bounds
let tmp%8#0: bool = (b== reinterpret_biguint%2#0 reinterpret_biguint%3#0)
(assert tmp%8#0)
return 1u

subroutine test_cases.arc4_types.structs.add(v1: bytes, v2: bytes) -> <bytes, bytes, bytes>:
block@0: // L64
block@0: // L68
let v1%is_original#0: bool = 1u
let v1%out#0: bytes = v1#0
let v2%is_original#0: bool = 1u
Expand All @@ -122,15 +136,15 @@ subroutine test_cases.arc4_types.structs.add(v1: bytes, v2: bytes) -> <bytes, by
return encoded_tuple_buffer%2#0 v1#0 v2#0

subroutine test_cases.arc4_types.structs.add_decimal(x: bytes, y: bytes) -> bytes:
block@0: // L86
block@0: // L90
let tmp%0#0: uint64 = (btoi x#0)
let tmp%1#0: uint64 = (btoi y#0)
let tmp%2#0: uint64 = (+ tmp%0#0 tmp%1#0)
let tmp%3#0: bytes = (itob tmp%2#0)
return tmp%3#0

subroutine test_cases.arc4_types.structs.check(flags: bytes) -> bytes:
block@0: // L72
block@0: // L76
let flags%is_original#0: bool = 1u
let flags%out#0: bytes = flags#0
let is_true%0#0: uint64 = (getbit flags#0 0u)
Expand All @@ -154,7 +168,7 @@ subroutine test_cases.arc4_types.structs.check(flags: bytes) -> bytes:
return flags%out#0

subroutine test_cases.arc4_types.structs.nested_decode(vector_flags: bytes) -> bytes:
block@0: // L80
block@0: // L84
let vector_flags%is_original#0: bool = 1u
let vector_flags%out#0: bytes = vector_flags#0
let tmp%0#0: bytes = (extract3 vector_flags#0 0u 16u) // on error: Index access is out of bounds
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ main test_cases.arc4_types.structs.Arc4StructsTypeContract.approval_program:
return 1u

subroutine test_cases.arc4_types.structs.add_decimal(x: bytes, y: bytes) -> bytes:
block@0: // L86
block@0: // L90
let tmp%0#0: uint64 = (btoi x#0)
let tmp%1#0: uint64 = (btoi y#0)
let tmp%2#0: uint64 = (+ tmp%0#0 tmp%1#0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ main_after_for@4:
// arc4_types/structs.py:44
// log(flags.bytes)
log (𝕗) val#2,loop_counter%0#0 |
// arc4_types/structs.py:58
// arc4_types/structs.py:62
// return True
int 1 (𝕗) val#2,loop_counter%0#0 | 1
return (𝕗) val#2,loop_counter%0#0 |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ main_after_for@4:
// arc4_types/structs.py:44
// log(flags.bytes)
log
// arc4_types/structs.py:58
// arc4_types/structs.py:62
// return True
int 1 1
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ main_after_for@4:
// arc4_types/structs.py:44
// log(flags.bytes)
log
// arc4_types/structs.py:58
// arc4_types/structs.py:62
// return True
int 1 1
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ main_after_for@4:
// arc4_types/structs.py:44
// log(flags.bytes)
log
// arc4_types/structs.py:58
// arc4_types/structs.py:62
// return True
int 1 1
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ main_after_for@4:
// arc4_types/structs.py:44
// log(flags.bytes)
log
// arc4_types/structs.py:58
// arc4_types/structs.py:62
// return True
int 1 1
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ main_after_for@4:
// arc4_types/structs.py:44
// log(flags.bytes)
log (𝕗) val#2,loop_counter%0#0 |
// arc4_types/structs.py:58
// arc4_types/structs.py:62
// return True
int 1 (𝕗) val#2,loop_counter%0#0 | 1
return (𝕗) val#2,loop_counter%0#0 |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ main_after_for@4:
// arc4_types/structs.py:44
// log(flags.bytes)
log (𝕗) val#2,loop_counter%0#0 |
// arc4_types/structs.py:58
// arc4_types/structs.py:62
// return True
int 1 (𝕗) val#2,loop_counter%0#0 | 1
return (𝕗) val#2,loop_counter%0#0 |
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
main test_cases.arc4_types.structs.Arc4StructsTypeContract.clear_state_program:
block@0: // L60
block@0: // L64
return 1u
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
main test_cases.arc4_types.structs.Arc4StructsTypeContract.clear_state_program:
block@0: // L60
block@0: // L64
return 1u
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Op Stack (out)
// test_cases.arc4_types.structs.Arc4StructsTypeContract.clear_state_program() -> uint64:
main:
// arc4_types/structs.py:61
// arc4_types/structs.py:65
// return True
int 1 1
return
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Op Stack (out)
// test_cases.arc4_types.structs.Arc4StructsTypeContract.clear_state_program() -> uint64:
main:
// arc4_types/structs.py:61
// arc4_types/structs.py:65
// return True
int 1 1
return
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Op Stack (out)
// test_cases.arc4_types.structs.Arc4StructsTypeContract.clear_state_program() -> uint64:
main:
// arc4_types/structs.py:61
// arc4_types/structs.py:65
// return True
int 1 1
return
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Op Stack (out)
// test_cases.arc4_types.structs.Arc4StructsTypeContract.clear_state_program() -> uint64:
main:
// arc4_types/structs.py:61
// arc4_types/structs.py:65
// return True
int 1 1
return
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Op Stack (out)
// test_cases.arc4_types.structs.Arc4StructsTypeContract.clear_state_program() -> uint64:
main:
// arc4_types/structs.py:61
// arc4_types/structs.py:65
// return True
int 1 1
return
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Op Stack (out)
// test_cases.arc4_types.structs.Arc4StructsTypeContract.clear_state_program() -> uint64:
main:
// arc4_types/structs.py:61
// arc4_types/structs.py:65
// return True
int 1 1
return
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Op Stack (out)
// test_cases.arc4_types.structs.Arc4StructsTypeContract.clear_state_program() -> uint64:
main:
// arc4_types/structs.py:61
// arc4_types/structs.py:65
// return True
int 1 1
return
Expand Down
3 changes: 3 additions & 0 deletions test_cases/arc4_types/out/module.awst
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ contract Arc4StructsTypeContract
immutable: test_cases.arc4_types.structs.FrozenAndImmutable = new test_cases.arc4_types.structs.FrozenAndImmutable(one=12_arc4u64, two=34_arc4u64)
no_copy: test_cases.arc4_types.structs.FrozenAndImmutable = immutable
assert(no_copy == immutable)
immutable2: test_cases.arc4_types.structs.FrozenAndImmutable = new test_cases.arc4_types.structs.FrozenAndImmutable(one=immutable.one, two=123_arc4u64)
assert(reinterpret_cast<biguint>(immutable2.two) == reinterpret_cast<biguint>(123_arc4u64))
assert(reinterpret_cast<biguint>(immutable2.one) == reinterpret_cast<biguint>(immutable.one))
return true
}

Expand Down
Loading

0 comments on commit 3b1268d

Please sign in to comment.