From 94ddaa1b0c903703fff89424178dad20387ed3df Mon Sep 17 00:00:00 2001 From: Ryan Tinianov Date: Thu, 2 Nov 2023 16:29:09 -0400 Subject: [PATCH] Add merge and split fields --- pkg/loop/internal/chain_reader_test.go | 6 +- pkg/utils/utils.go | 85 +++++++++++++++++++++++++ pkg/utils/utils_test.go | 87 ++++++++++++++++++++++++++ 3 files changed, 176 insertions(+), 2 deletions(-) diff --git a/pkg/loop/internal/chain_reader_test.go b/pkg/loop/internal/chain_reader_test.go index 733f00ad9..4ebc34e25 100644 --- a/pkg/loop/internal/chain_reader_test.go +++ b/pkg/loop/internal/chain_reader_test.go @@ -2,17 +2,19 @@ package internal import ( "errors" - "github.com/stretchr/testify/assert" - "google.golang.org/grpc/test/bufconn" "net" "sync" "testing" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/test/bufconn" + "github.com/stretchr/testify/require" "github.com/smartcontractkit/chainlink-relay/pkg/loop/internal/pb" "context" + "github.com/fxamacker/cbor/v2" "github.com/mitchellh/mapstructure" ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types" diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index c2bb12ba3..29a275d8f 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -2,10 +2,14 @@ package utils import ( "context" + "encoding/base64" "math" "math/big" mrand "math/rand" + "reflect" "time" + + "github.com/smartcontractkit/chainlink-relay/pkg/types" ) // WithJitter adds +/- 10% to a duration @@ -59,3 +63,84 @@ func FitsInNBitsSigned(n int, bi *big.Int) bool { } return bi.BitLen() <= n-1 } + +func MergeValueFields(valueFields []map[string]any) (map[string]any, error) { + numItems := len(valueFields) + + switch numItems { + case 0: + return map[string]any{}, nil + default: + mergedReflect := map[string]reflect.Value{} + for k, v := range valueFields[0] { + rv := reflect.ValueOf(v) + slice := reflect.MakeSlice(reflect.SliceOf(rv.Type()), numItems, numItems) + slice.Index(0).Set(rv) + mergedReflect[k] = slice + } + + for i, valueField := range valueFields[1:] { + if len(valueField) != len(mergedReflect) { + return nil, types.InvalidTypeError{} + } + + for k, slice := range mergedReflect { + if value, ok := valueField[k]; ok { + sliceElm := slice.Index(i + 1) + rv := reflect.ValueOf(value) + if !rv.Type().AssignableTo(sliceElm.Type()) { + return nil, types.InvalidTypeError{} + } + sliceElm.Set(rv) + } else { + return nil, types.InvalidTypeError{} + } + } + } + + merged := map[string]any{} + + for k, v := range mergedReflect { + merged[k] = v.Interface() + } + + return merged, nil + } +} + +func SplitValueFields(decoded map[string]any) ([]map[string]any, error) { + var result []map[string]any + + for k, v := range decoded { + iv := reflect.ValueOf(v) + kind := iv.Kind() + if kind != reflect.Slice && kind != reflect.Array { + if kind != reflect.String { + return nil, types.NotASliceError{} + } + rawBytes, err := base64.StdEncoding.DecodeString(v.(string)) + if err != nil { + return nil, types.InvalidTypeError{} + } + iv = reflect.ValueOf(rawBytes) + } + + length := iv.Len() + if result == nil { + result = make([]map[string]any, length) + for i := 0; i < length; i++ { + result[i] = map[string]any{} + } + } + + if len(result) != length { + return nil, types.InvalidTypeError{} + } + + for i := 0; i < length; i++ { + result[i][k] = iv.Index(i).Interface() + } + } + + return result, nil +} diff --git a/pkg/utils/utils_test.go b/pkg/utils/utils_test.go index 678f91192..ee2f741e5 100644 --- a/pkg/utils/utils_test.go +++ b/pkg/utils/utils_test.go @@ -5,8 +5,11 @@ import ( "math/big" "testing" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/assert" + "github.com/smartcontractkit/chainlink-relay/pkg/types" "github.com/smartcontractkit/chainlink-relay/pkg/utils" ) @@ -27,3 +30,87 @@ func TestFitsInNBitsSigned(t *testing.T) { assert.False(t, utils.FitsInNBitsSigned(16, bi)) }) } + +func TestMergeValueFields(t *testing.T) { + t.Parallel() + t.Run("Merges fields", func(t *testing.T) { + input := []map[string]any{ + {"Foo": int32(1), "Bar": "Hi"}, + {"Foo": int32(2), "Bar": "How"}, + {"Foo": int32(3), "Bar": "Are"}, + {"Foo": int32(4), "Bar": "You?"}, + } + + output, err := utils.MergeValueFields(input) + require.NoError(t, err) + + expected := map[string]any{ + "Foo": []int32{1, 2, 3, 4}, + "Bar": []string{"Hi", "How", "Are", "You?"}, + } + assert.Equal(t, expected, output) + }) + + t.Run("Returns error if keys are not the same", func(t *testing.T) { + input := []map[string]any{ + {"Foo": int32(1), "Bar": "Hi"}, + {"Zap": 2, "Foo": int32(2), "Bar": "How"}, + } + + _, err := utils.MergeValueFields(input) + + assert.IsType(t, types.InvalidTypeError{}, err) + }) + + t.Run("Returns error if values are not compatible types", func(t *testing.T) { + input := []map[string]any{ + {"Foo": int32(1), "Bar": "Hi"}, + {"Foo": int32(2), "Bar": int32(3)}, + } + + _, err := utils.MergeValueFields(input) + + assert.IsType(t, types.InvalidTypeError{}, err) + }) +} + +func TestSplitValueField(t *testing.T) { + t.Parallel() + t.Run("Returns slit field values", func(t *testing.T) { + input := map[string]any{ + "Foo": []int32{1, 2, 3, 4}, + "Bar": [4]string{"Hi", "How", "Are", "You?"}, + } + + output, err := utils.SplitValueFields(input) + require.NoError(t, err) + + expected := []map[string]any{ + {"Foo": int32(1), "Bar": "Hi"}, + {"Foo": int32(2), "Bar": "How"}, + {"Foo": int32(3), "Bar": "Are"}, + {"Foo": int32(4), "Bar": "You?"}, + } + assert.Equal(t, expected, output) + }) + + t.Run("Returns error if lengths do not match", func(t *testing.T) { + input := map[string]any{ + "Foo": []int32{1, 2, 3}, + "Bar": []string{"Hi", "How", "Are", "You?"}, + } + + _, err := utils.SplitValueFields(input) + assert.IsType(t, types.InvalidTypeError{}, err) + }) + + t.Run("Returns error if item is not an array or slice", func(t *testing.T) { + input := map[string]any{ + "Foo": int32(3), + "Bar": []string{"Hi", "How", "Are", "You?"}, + } + + _, err := utils.SplitValueFields(input) + assert.IsType(t, types.NotASliceError{}, err) + }) +}