diff --git a/pkg/scale/decode.go b/pkg/scale/decode.go index 45a527f0b8..bcd99c8d1c 100644 --- a/pkg/scale/decode.go +++ b/pkg/scale/decode.go @@ -75,6 +75,11 @@ func Unmarshal(data []byte, dst interface{}) (err error) { return } +// Unmarshaler is the interface for custom SCALE unmarshalling for a given type +type Unmarshaler interface { + UnmarshalSCALE(io.Reader) error +} + // Decoder is used to decode from an io.Reader type Decoder struct { decodeState @@ -108,6 +113,18 @@ type decodeState struct { } func (ds *decodeState) unmarshal(dstv reflect.Value) (err error) { + unmarshalerType := reflect.TypeOf((*Unmarshaler)(nil)).Elem() + if dstv.CanAddr() && dstv.Addr().Type().Implements(unmarshalerType) { + methodVal := dstv.Addr().MethodByName("UnmarshalSCALE") + values := methodVal.Call([]reflect.Value{reflect.ValueOf(ds.Reader)}) + if !values[0].IsNil() { + errIn := values[0].Interface() + err := errIn.(error) + return err + } + return + } + in := dstv.Interface() switch in.(type) { case *big.Int: diff --git a/pkg/scale/decode_test.go b/pkg/scale/decode_test.go index 3309c58a9b..713f3a7bce 100644 --- a/pkg/scale/decode_test.go +++ b/pkg/scale/decode_test.go @@ -5,6 +5,9 @@ package scale import ( "bytes" + "encoding/binary" + "fmt" + "io" "math/big" "reflect" "testing" @@ -535,3 +538,85 @@ func Test_decodeState_decodeUint(t *testing.T) { }) } } + +type myStruct struct { + First uint32 + Middle any + Last uint32 +} + +func (ms *myStruct) UnmarshalSCALE(reader io.Reader) (err error) { + buf := make([]byte, 4) + _, err = reader.Read(buf) + if err != nil { + return + } + ms.First = binary.LittleEndian.Uint32(buf) + + buf = make([]byte, 4) + _, err = reader.Read(buf) + if err != nil { + return + } + ms.Middle = binary.LittleEndian.Uint32(buf) + + buf = make([]byte, 4) + _, err = reader.Read(buf) + if err != nil { + return + } + ms.Last = binary.LittleEndian.Uint32(buf) + return nil +} + +type myStructError struct { + First uint32 + Middle any + Last uint32 +} + +func (mse *myStructError) UnmarshalSCALE(reader io.Reader) (err error) { + err = fmt.Errorf("eh?") + return err +} + +var _ Unmarshaler = &myStruct{} + +func Test_decodeState_Unmarshaller(t *testing.T) { + expected := myStruct{ + First: 1, + Middle: uint32(2), + Last: 3, + } + bytes := MustMarshal(expected) + ms := myStruct{} + Unmarshal(bytes, &ms) + assert.Equal(t, expected, ms) + + type myParentStruct struct { + First uint + Middle myStruct + Last uint + } + expectedParent := myParentStruct{ + First: 1, + Middle: expected, + Last: 3, + } + bytes = MustMarshal(expectedParent) + mps := myParentStruct{} + Unmarshal(bytes, &mps) + assert.Equal(t, expectedParent, mps) +} + +func Test_decodeState_Unmarshaller_Error(t *testing.T) { + expected := myStruct{ + First: 1, + Middle: uint32(2), + Last: 3, + } + bytes := MustMarshal(expected) + mse := myStructError{} + err := Unmarshal(bytes, &mse) + assert.Error(t, err, "eh?") +} diff --git a/pkg/scale/encode.go b/pkg/scale/encode.go index d312b85f91..c9830aef9d 100644 --- a/pkg/scale/encode.go +++ b/pkg/scale/encode.go @@ -47,6 +47,11 @@ func Marshal(v interface{}) (b []byte, err error) { return } +// Marshaler is the interface for custom SCALE marshalling for a given type +type Marshaler interface { + MarshalSCALE() ([]byte, error) +} + // MustMarshal runs Marshal and panics on error. func MustMarshal(v interface{}) (b []byte) { b, err := Marshal(v) @@ -62,6 +67,17 @@ type encodeState struct { } func (es *encodeState) marshal(in interface{}) (err error) { + marshaler, ok := in.(Marshaler) + if ok { + var bytes []byte + bytes, err = marshaler.MarshalSCALE() + if err != nil { + return + } + _, err = es.Write(bytes) + return + } + switch in := in.(type) { case int: err = es.encodeUint(uint(in)) diff --git a/pkg/scale/encode_test.go b/pkg/scale/encode_test.go index 92de411919..8f5d9a60ca 100644 --- a/pkg/scale/encode_test.go +++ b/pkg/scale/encode_test.go @@ -5,6 +5,7 @@ package scale import ( "bytes" + "fmt" "math/big" "reflect" "strings" @@ -1250,3 +1251,25 @@ var byteArray = func(length int) []byte { } return b } + +type myMarshalerType uint64 + +func (mmt myMarshalerType) MarshalSCALE() ([]byte, error) { + return []byte{9, 9, 9}, nil +} + +type myMarshalerTypeError uint64 + +func (mmt myMarshalerTypeError) MarshalSCALE() ([]byte, error) { + return nil, fmt.Errorf("eh?") +} + +func Test_encodeState_Mashaler(t *testing.T) { + bytes := MustMarshal(myMarshalerType(888)) + assert.Equal(t, []byte{9, 9, 9}, bytes) +} + +func Test_encodeState_Mashaler_Error(t *testing.T) { + _, err := Marshal(myMarshalerTypeError(888)) + assert.Error(t, err, "eh?") +}