Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pkg/scale): encoding and decoding of maps in scale #2894

Merged
merged 11 commits into from
Oct 21, 2022
30 changes: 30 additions & 0 deletions pkg/scale/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ func (ds *decodeState) unmarshal(dstv reflect.Value) (err error) {
err = ds.decodeArray(dstv)
case reflect.Slice:
err = ds.decodeSlice(dstv)
case reflect.Map:
err = ds.decodeMap(dstv)
default:
err = fmt.Errorf("%w: %T", ErrUnsupportedType, in)
}
Expand Down Expand Up @@ -417,6 +419,34 @@ func (ds *decodeState) decodeArray(dstv reflect.Value) (err error) {
return
}

func (ds *decodeState) decodeMap(dstv reflect.Value) (err error) {
axaysagathiya marked this conversation as resolved.
Show resolved Hide resolved
numberOfTuples, err := ds.decodeLength()
if err != nil {
return fmt.Errorf("decoding length: %w", err)
}
in := dstv.Interface()

for i := uint(0); i < numberOfTuples; i++ {
tempKeyType := reflect.TypeOf(in).Key()
tempKey := reflect.New(tempKeyType).Elem()
err = ds.unmarshal(tempKey)
if err != nil {
return fmt.Errorf("decoding key %d of %d: %w", i+1, numberOfTuples, err)
}

tempElemType := reflect.TypeOf(in).Elem()
tempElem := reflect.New(tempElemType).Elem()
err = ds.unmarshal(tempElem)
if err != nil {
return fmt.Errorf("decoding value %d of %d: %w", i+1, numberOfTuples, err)
}

dstv.SetMapIndex(tempKey, tempElem)
}

return nil
}

// decodeStruct decodes a byte array representing a SCALE tuple. The order of data is
// determined by the source tuple in rust, or the struct field order in a go struct
func (ds *decodeState) decodeStruct(dstv reflect.Value) (err error) {
Expand Down
96 changes: 95 additions & 1 deletion pkg/scale/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,93 @@ func Test_decodeState_decodeSlice(t *testing.T) {
}
}

type user struct {
Active bool
Username string
Email string
SignInCount uint64
}

func Test_decodeState_decodeMap(t *testing.T) {
mapTests1 := []struct {
name string
input []byte
wantErr bool
expectedOutput map[int8][]byte
}{
{
name: "testing a map of int8 to a byte array 1",
input: []byte{4, 2, 44, 115, 111, 109, 101, 32, 115, 116, 114, 105, 110, 103},
expectedOutput: map[int8][]byte{2: []byte("some string")},
},
{
name: "testing a map of int8 to a byte array 2",
input: []byte{
8, 2, 44, 115, 111, 109, 101, 32, 115, 116, 114, 105, 110, 103, 16, 44, 108, 111, 114, 101, 109, 32,
105, 112, 115, 117, 109,
},
expectedOutput: map[int8][]byte{
2: []byte("some string"),
16: []byte("lorem ipsum"),
},
},
}

for _, tt := range mapTests1 {
tt := tt
t.Run(tt.name, func(t *testing.T) {
actualOutput := make(map[int8][]byte)
if err := Unmarshal(tt.input, &actualOutput); (err != nil) != tt.wantErr {
t.Errorf("decodeState.unmarshal() error = %v, wantErr %v", err, tt.wantErr)
}

if !reflect.DeepEqual(actualOutput, tt.expectedOutput) {
t.Errorf("decodeState.unmarshal() = %v, want %v", actualOutput, tt.expectedOutput)
}
})
}

mapTests2 := []struct {
name string
input []byte
wantErr bool
expectedOutput map[string]user
}{
{
name: "testing a map of string to struct",
input: []byte{8, 28, 115, 116, 114, 105, 110, 103, 49, 1, 20, 108, 111, 114, 101, 109, 60, 108, 111, 114, 101, 109, 64, 105, 112, 115, 117, 109, 46, 111, 114, 103, 1, 0, 0, 0, 0, 0, 0, 0, 28, 115, 116, 114, 105, 110, 103, 50, 0, 16, 106, 111, 104, 110, 56, 106, 97, 99, 107, 64, 103, 109, 97, 105, 108, 46, 99, 111, 109, 73, 0, 0, 0, 0, 0, 0, 0}, //nolint:lll
expectedOutput: map[string]user{
"string1": {
Active: true,
Username: "lorem",
Email: "[email protected]",
SignInCount: 1,
},
"string2": {
Active: false,
Username: "john",
Email: "[email protected]",
SignInCount: 73,
},
},
},
}

for _, tt := range mapTests2 {
tt := tt
t.Run(tt.name, func(t *testing.T) {
actualOutput := make(map[string]user)
if err := Unmarshal(tt.input, &actualOutput); (err != nil) != tt.wantErr {
t.Errorf("decodeState.unmarshal() error = %v, wantErr %v", err, tt.wantErr)
}

if !reflect.DeepEqual(actualOutput, tt.expectedOutput) {
t.Errorf("decodeState.unmarshal() = %v, want %v", actualOutput, tt.expectedOutput)
}
})
}
}

func Test_unmarshal_optionality(t *testing.T) {
var ptrTests tests
for _, t := range append(tests{}, allTests...) {
Expand Down Expand Up @@ -167,7 +254,14 @@ func Test_unmarshal_optionality(t *testing.T) {
t.Errorf("decodeState.unmarshal() = %s", diff)
}
default:
dst := reflect.New(reflect.TypeOf(tt.in)).Interface()
jimjbrettj marked this conversation as resolved.
Show resolved Hide resolved
var dst interface{}

if reflect.TypeOf(tt.in).Kind().String() == "map" {
dst = &(map[int8][]byte{})
} else {
dst = reflect.New(reflect.TypeOf(tt.in)).Interface()
}

if err := Unmarshal(tt.want, &dst); (err != nil) != tt.wantErr {
t.Errorf("decodeState.unmarshal() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down
37 changes: 37 additions & 0 deletions pkg/scale/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"io"
"math/big"
"reflect"
"sort"
)

// Encoder scale encodes to a given io.Writer.
Expand Down Expand Up @@ -106,6 +107,8 @@ func (es *encodeState) marshal(in interface{}) (err error) {
err = es.encodeArray(in)
case reflect.Slice:
err = es.encodeSlice(in)
case reflect.Map:
err = es.encodeMap(in)
default:
err = fmt.Errorf("%w: %T", ErrUnsupportedType, in)
}
Expand Down Expand Up @@ -223,6 +226,40 @@ func (es *encodeState) encodeArray(in interface{}) (err error) {
return
}

func (es *encodeState) encodeMap(in interface{}) (err error) {
v := reflect.ValueOf(in)
err = es.encodeLength(v.Len())
if err != nil {
return fmt.Errorf("encoding length: %w", err)
}

mapKeys := v.MapKeys()

sort.Slice(mapKeys, func(i, j int) bool {
keyByteOfI, _ := Marshal(mapKeys[i].Interface())
keyByteOfJ, _ := Marshal(mapKeys[j].Interface())
return bytes.Compare(keyByteOfI, keyByteOfJ) < 0
})
Comment on lines +238 to +242
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are we doing this? Is this just so we can have deterministic tests?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@axaysagathiya can you create another PR to only do this in the tests? This is not required when encoding maps in production, and it's not very performant. I'd suggest adding a sortKeys parameter to the encodeMap function signature.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, my bad!

This is not required when encoding maps

This is not just needed in test, this is needed in encoding as well. keys in go maps are unordered. so running encode function on the same go map could result into different encodings. Rust equivalent of go map is BTree map, where keys are sorted. Scale encode function in the Rust implementation does not produces different values on running it multiple times.

So, in order to make sure that our scale package and Rust's scale package encodes a map to the same value sorting is necessary.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just because the BTreeMap has deterministic sorting of keys, doesn't mean that the parity-scale-codec isn't able to decode maps that were encoded with keys in non-deterministic order. In go it's understood that the order of the keys is non-deterministic. In this case, we should be testing that we can encode a map and decode into a map, and assert that the maps are equal. We could also ensure that decoding into a rust BTreeMap via the parity-scale-codec works as expected. To make the tests more deterministic, you could provide a map with a single key or a map with two keys, and just assert that it's equal to one of the two encodings and remove the sorting logic altogether.


for _, key := range mapKeys {
err = es.marshal(key.Interface())
if err != nil {
return fmt.Errorf("encoding map key: %w", err)
}

mapValue := v.MapIndex(key)
if !mapValue.CanInterface() {
continue
}
Comment on lines +251 to +253
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we return an error in that case, to avoid silently discarding a map value when encoding?

@timwu20

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Def wait for Tims answer, but this happens at other places in this file as well (i.e. structs)

Copy link
Contributor

@timwu20 timwu20 Oct 25, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CanInterface is essentially determining if this is a public attribute of a struct or public method. I think it should be fine in this case. This will almost always return true in the case we're trying to decode into a map value. I wonder what cases this would return false though, maybe we can provide a test case.

@kishansagathiya please don't merge the PR with unresolved conversations. My bad for not getting to this earlier.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realised about this comment few moments after merging it. Wasn't deliberate.


err = es.marshal(mapValue.Interface())
if err != nil {
return fmt.Errorf("encoding map value: %w", err)
}
}
return nil
}

// encodeBigInt performs the same encoding as encodeInteger, except on a big.Int.
// if 2^30 <= n < 2^536 write
// [lower 2 bits of first byte = 11] [upper 6 bits of first byte = # of bytes following less 4]
Expand Down
40 changes: 39 additions & 1 deletion pkg/scale/encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -909,9 +909,28 @@ var (
},
}

mapTests = tests{
{
name: "testMap1",
in: map[int8][]byte{2: []byte("some string")},
want: []byte{4, 2, 44, 115, 111, 109, 101, 32, 115, 116, 114, 105, 110, 103},
jimjbrettj marked this conversation as resolved.
Show resolved Hide resolved
},
{
name: "testMap2",
in: map[int8][]byte{
2: []byte("some string"),
16: []byte("lorem ipsum"),
},
want: []byte{
8, 2, 44, 115, 111, 109, 101, 32, 115, 116, 114, 105, 110, 103, 16, 44, 108, 111, 114, 101, 109, 32,
105, 112, 115, 117, 109,
},
},
}

allTests = newTests(
fixedWidthIntegerTests, variableWidthIntegerTests, stringTests,
boolTests, structTests, sliceTests, arrayTests,
boolTests, structTests, sliceTests, arrayTests, mapTests,
varyingDataTypeTests,
)
)
Expand Down Expand Up @@ -1096,6 +1115,25 @@ func Test_encodeState_encodeArray(t *testing.T) {
}
}

func Test_encodeState_encodeMap(t *testing.T) {
for _, tt := range mapTests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
buffer := bytes.NewBuffer(nil)
es := &encodeState{
Writer: buffer,
fieldScaleIndicesCache: cache,
}
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
t.Errorf("encodeState.encodeMap() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeMap() = %v, want %v", buffer.Bytes(), tt.want)
}
})
}
}

func Test_marshal_optionality(t *testing.T) {
var ptrTests tests
for i := range allTests {
Expand Down