diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index c2bb12ba32..29a275d8f3 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -2,10 +2,14 @@ package utils import ( "context" + "encoding/base64" "math" "math/big" mrand "math/rand" + "reflect" "time" + + "github.com/smartcontractkit/chainlink-relay/pkg/types" ) // WithJitter adds +/- 10% to a duration @@ -59,3 +63,84 @@ func FitsInNBitsSigned(n int, bi *big.Int) bool { } return bi.BitLen() <= n-1 } + +func MergeValueFields(valueFields []map[string]any) (map[string]any, error) { + numItems := len(valueFields) + + switch numItems { + case 0: + return map[string]any{}, nil + default: + mergedReflect := map[string]reflect.Value{} + for k, v := range valueFields[0] { + rv := reflect.ValueOf(v) + slice := reflect.MakeSlice(reflect.SliceOf(rv.Type()), numItems, numItems) + slice.Index(0).Set(rv) + mergedReflect[k] = slice + } + + for i, valueField := range valueFields[1:] { + if len(valueField) != len(mergedReflect) { + return nil, types.InvalidTypeError{} + } + + for k, slice := range mergedReflect { + if value, ok := valueField[k]; ok { + sliceElm := slice.Index(i + 1) + rv := reflect.ValueOf(value) + if !rv.Type().AssignableTo(sliceElm.Type()) { + return nil, types.InvalidTypeError{} + } + sliceElm.Set(rv) + } else { + return nil, types.InvalidTypeError{} + } + } + } + + merged := map[string]any{} + + for k, v := range mergedReflect { + merged[k] = v.Interface() + } + + return merged, nil + } +} + +func SplitValueFields(decoded map[string]any) ([]map[string]any, error) { + var result []map[string]any + + for k, v := range decoded { + iv := reflect.ValueOf(v) + kind := iv.Kind() + if kind != reflect.Slice && kind != reflect.Array { + if kind != reflect.String { + return nil, types.NotASliceError{} + } + rawBytes, err := base64.StdEncoding.DecodeString(v.(string)) + if err != nil { + return nil, types.InvalidTypeError{} + } + iv = reflect.ValueOf(rawBytes) + } + + length := iv.Len() + if result == nil { + result = make([]map[string]any, length) + for i := 0; i < length; i++ { + result[i] = map[string]any{} + } + } + + if len(result) != length { + return nil, types.InvalidTypeError{} + } + + for i := 0; i < length; i++ { + result[i][k] = iv.Index(i).Interface() + } + } + + return result, nil +} diff --git a/pkg/utils/utils_test.go b/pkg/utils/utils_test.go index 678f91192b..189a9a7441 100644 --- a/pkg/utils/utils_test.go +++ b/pkg/utils/utils_test.go @@ -1,12 +1,14 @@ package utils_test import ( + "github.com/stretchr/testify/require" "math" "math/big" "testing" "github.com/stretchr/testify/assert" + "github.com/smartcontractkit/chainlink-relay/pkg/types" "github.com/smartcontractkit/chainlink-relay/pkg/utils" ) @@ -27,3 +29,87 @@ func TestFitsInNBitsSigned(t *testing.T) { assert.False(t, utils.FitsInNBitsSigned(16, bi)) }) } + +func TestMergeValueFields(t *testing.T) { + t.Parallel() + t.Run("Merges fields", func(t *testing.T) { + input := []map[string]any{ + {"Foo": int32(1), "Bar": "Hi"}, + {"Foo": int32(2), "Bar": "How"}, + {"Foo": int32(3), "Bar": "Are"}, + {"Foo": int32(4), "Bar": "You?"}, + } + + output, err := utils.MergeValueFields(input) + require.NoError(t, err) + + expected := map[string]any{ + "Foo": []int32{1, 2, 3, 4}, + "Bar": []string{"Hi", "How", "Are", "You?"}, + } + assert.Equal(t, expected, output) + }) + + t.Run("Returns error if keys are not the same", func(t *testing.T) { + input := []map[string]any{ + {"Foo": int32(1), "Bar": "Hi"}, + {"Zap": 2, "Foo": int32(2), "Bar": "How"}, + } + + _, err := utils.MergeValueFields(input) + + assert.IsType(t, types.InvalidTypeError{}, err) + }) + + t.Run("Returns error if values are not compatible types", func(t *testing.T) { + input := []map[string]any{ + {"Foo": int32(1), "Bar": "Hi"}, + {"Foo": int32(2), "Bar": int32(3)}, + } + + _, err := utils.MergeValueFields(input) + + assert.IsType(t, types.InvalidTypeError{}, err) + }) +} + +func TestSplitValueField(t *testing.T) { + t.Parallel() + t.Run("Returns slit field values", func(t *testing.T) { + input := map[string]any{ + "Foo": []int32{1, 2, 3, 4}, + "Bar": [4]string{"Hi", "How", "Are", "You?"}, + } + + output, err := utils.SplitValueFields(input) + require.NoError(t, err) + + expected := []map[string]any{ + {"Foo": int32(1), "Bar": "Hi"}, + {"Foo": int32(2), "Bar": "How"}, + {"Foo": int32(3), "Bar": "Are"}, + {"Foo": int32(4), "Bar": "You?"}, + } + assert.Equal(t, expected, output) + }) + + t.Run("Returns error if lengths do not match", func(t *testing.T) { + input := map[string]any{ + "Foo": []int32{1, 2, 3}, + "Bar": []string{"Hi", "How", "Are", "You?"}, + } + + _, err := utils.SplitValueFields(input) + assert.IsType(t, types.InvalidTypeError{}, err) + }) + + t.Run("Returns error if item is not an array or slice", func(t *testing.T) { + input := map[string]any{ + "Foo": int32(3), + "Bar": []string{"Hi", "How", "Are", "You?"}, + } + + _, err := utils.SplitValueFields(input) + assert.IsType(t, types.NotASliceError{}, err) + }) +}