Skip to content

Commit

Permalink
feat(pkg/scale): Add Marshaler and Unmarshaler interfaces and fun…
Browse files Browse the repository at this point in the history
…ctionality (#3617)
  • Loading branch information
timwu20 authored Dec 4, 2023
1 parent 11b96dc commit 4888ce4
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 0 deletions.
17 changes: 17 additions & 0 deletions pkg/scale/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ func Unmarshal(data []byte, dst interface{}) (err error) {
return
}

// Unmarshaler is the interface for custom SCALE unmarshalling for a given type
type Unmarshaler interface {
UnmarshalSCALE(io.Reader) error
}

// Decoder is used to decode from an io.Reader
type Decoder struct {
decodeState
Expand Down Expand Up @@ -108,6 +113,18 @@ type decodeState struct {
}

func (ds *decodeState) unmarshal(dstv reflect.Value) (err error) {
unmarshalerType := reflect.TypeOf((*Unmarshaler)(nil)).Elem()
if dstv.CanAddr() && dstv.Addr().Type().Implements(unmarshalerType) {
methodVal := dstv.Addr().MethodByName("UnmarshalSCALE")
values := methodVal.Call([]reflect.Value{reflect.ValueOf(ds.Reader)})
if !values[0].IsNil() {
errIn := values[0].Interface()
err := errIn.(error)
return err
}
return
}

in := dstv.Interface()
switch in.(type) {
case *big.Int:
Expand Down
85 changes: 85 additions & 0 deletions pkg/scale/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ package scale

import (
"bytes"
"encoding/binary"
"fmt"
"io"
"math/big"
"reflect"
"testing"
Expand Down Expand Up @@ -535,3 +538,85 @@ func Test_decodeState_decodeUint(t *testing.T) {
})
}
}

type myStruct struct {
First uint32
Middle any
Last uint32
}

func (ms *myStruct) UnmarshalSCALE(reader io.Reader) (err error) {
buf := make([]byte, 4)
_, err = reader.Read(buf)
if err != nil {
return
}
ms.First = binary.LittleEndian.Uint32(buf)

buf = make([]byte, 4)
_, err = reader.Read(buf)
if err != nil {
return
}
ms.Middle = binary.LittleEndian.Uint32(buf)

buf = make([]byte, 4)
_, err = reader.Read(buf)
if err != nil {
return
}
ms.Last = binary.LittleEndian.Uint32(buf)
return nil
}

type myStructError struct {
First uint32
Middle any
Last uint32
}

func (mse *myStructError) UnmarshalSCALE(reader io.Reader) (err error) {
err = fmt.Errorf("eh?")
return err
}

var _ Unmarshaler = &myStruct{}

func Test_decodeState_Unmarshaller(t *testing.T) {
expected := myStruct{
First: 1,
Middle: uint32(2),
Last: 3,
}
bytes := MustMarshal(expected)
ms := myStruct{}
Unmarshal(bytes, &ms)
assert.Equal(t, expected, ms)

type myParentStruct struct {
First uint
Middle myStruct
Last uint
}
expectedParent := myParentStruct{
First: 1,
Middle: expected,
Last: 3,
}
bytes = MustMarshal(expectedParent)
mps := myParentStruct{}
Unmarshal(bytes, &mps)
assert.Equal(t, expectedParent, mps)
}

func Test_decodeState_Unmarshaller_Error(t *testing.T) {
expected := myStruct{
First: 1,
Middle: uint32(2),
Last: 3,
}
bytes := MustMarshal(expected)
mse := myStructError{}
err := Unmarshal(bytes, &mse)
assert.Error(t, err, "eh?")
}
16 changes: 16 additions & 0 deletions pkg/scale/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ func Marshal(v interface{}) (b []byte, err error) {
return
}

// Marshaler is the interface for custom SCALE marshalling for a given type
type Marshaler interface {
MarshalSCALE() ([]byte, error)
}

// MustMarshal runs Marshal and panics on error.
func MustMarshal(v interface{}) (b []byte) {
b, err := Marshal(v)
Expand All @@ -62,6 +67,17 @@ type encodeState struct {
}

func (es *encodeState) marshal(in interface{}) (err error) {
marshaler, ok := in.(Marshaler)
if ok {
var bytes []byte
bytes, err = marshaler.MarshalSCALE()
if err != nil {
return
}
_, err = es.Write(bytes)
return
}

switch in := in.(type) {
case int:
err = es.encodeUint(uint(in))
Expand Down
23 changes: 23 additions & 0 deletions pkg/scale/encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package scale

import (
"bytes"
"fmt"
"math/big"
"reflect"
"strings"
Expand Down Expand Up @@ -1250,3 +1251,25 @@ var byteArray = func(length int) []byte {
}
return b
}

type myMarshalerType uint64

func (mmt myMarshalerType) MarshalSCALE() ([]byte, error) {
return []byte{9, 9, 9}, nil
}

type myMarshalerTypeError uint64

func (mmt myMarshalerTypeError) MarshalSCALE() ([]byte, error) {
return nil, fmt.Errorf("eh?")
}

func Test_encodeState_Mashaler(t *testing.T) {
bytes := MustMarshal(myMarshalerType(888))
assert.Equal(t, []byte{9, 9, 9}, bytes)
}

func Test_encodeState_Mashaler_Error(t *testing.T) {
_, err := Marshal(myMarshalerTypeError(888))
assert.Error(t, err, "eh?")
}

0 comments on commit 4888ce4

Please sign in to comment.