Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reflect Marshaler #1592

Merged
merged 15 commits into from
Oct 10, 2024
143 changes: 77 additions & 66 deletions abi/dynamic/reflect_marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ import (
"encoding/json"
"fmt"
"reflect"
"regexp"
"strconv"
"strings"

"github.com/ava-labs/avalanchego/utils/wrappers"
"golang.org/x/text/cases"
"golang.org/x/text/language"

Expand All @@ -18,58 +20,62 @@ import (
"github.com/ava-labs/hypersdk/consts"
)

func DynamicMarshal(inputAbi abi.ABI, typeName string, jsonData string) ([]byte, error) {
// Find the type in the ABI
abiType := findABIType(inputAbi, typeName)
if abiType == nil {
// Matches fixed-size arrays like [32]uint8
var fixedSizeArrayRegex = regexp.MustCompile(`^\[(\d+)\](.+)$`)

func Marshal(inputABI abi.ABI, typeName string, jsonData string) ([]byte, error) {
_, ok := findABIType(inputABI, typeName)
if !ok {
return nil, fmt.Errorf("type %s not found in ABI", typeName)
}

// Create a cache to avoid rebuilding types
typeCache := make(map[string]reflect.Type)

// Create a dynamic struct type
dynamicType := getReflectType(typeName, inputAbi, typeCache)
typ, err := getReflectType(typeName, inputABI, typeCache)
if err != nil {
return nil, fmt.Errorf("failed to get reflect type: %w", err)
}

// Create an instance of the dynamic struct
dynamicValue := reflect.New(dynamicType).Interface()
value := reflect.New(typ).Interface()

// Unmarshal JSON data into the dynamic struct
if err := json.Unmarshal([]byte(jsonData), dynamicValue); err != nil {
err = json.Unmarshal([]byte(jsonData), value)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal JSON data: %w", err)
}

// Marshal the dynamic struct using LinearCodec
writer := codec.NewWriter(0, consts.NetworkSizeLimit)
if err := codec.LinearCodec.MarshalInto(dynamicValue, writer.Packer); err != nil {
err = codec.LinearCodec.MarshalInto(value, writer.Packer)
if err != nil {
return nil, fmt.Errorf("failed to marshal struct: %w", err)
}

return writer.Bytes(), nil
}

func DynamicUnmarshal(inputAbi abi.ABI, typeName string, data []byte) (string, error) {
// Find the type in the ABI
abiType := findABIType(inputAbi, typeName)
if abiType == nil {
func Unmarshal(inputABI abi.ABI, typeName string, data []byte) (string, error) {
_, ok := findABIType(inputABI, typeName)
if !ok {
return "", fmt.Errorf("type %s not found in ABI", typeName)
}

// Create a cache to avoid rebuilding types
typeCache := make(map[string]reflect.Type)

// Create a dynamic struct type
dynamicType := getReflectType(typeName, inputAbi, typeCache)
dynamicType, err := getReflectType(typeName, inputABI, typeCache)
if err != nil {
return "", fmt.Errorf("failed to get reflect type: %w", err)
}

// Create an instance of the dynamic struct
dynamicValue := reflect.New(dynamicType).Interface()

// Unmarshal the data into the dynamic struct
if err := codec.LinearCodec.Unmarshal(data, dynamicValue); err != nil {
packer := wrappers.Packer{
Bytes: data,
MaxSize: consts.NetworkSizeLimit,
}
err = codec.LinearCodec.UnmarshalFrom(&packer, dynamicValue)
if err != nil {
return "", fmt.Errorf("failed to unmarshal data: %w", err)
}

// Marshal the dynamic struct back to JSON
jsonData, err := json.Marshal(dynamicValue)
if err != nil {
return "", fmt.Errorf("failed to marshal struct to JSON: %w", err)
Expand All @@ -78,85 +84,90 @@ func DynamicUnmarshal(inputAbi abi.ABI, typeName string, data []byte) (string, e
return string(jsonData), nil
}

func getReflectType(abiTypeName string, inputAbi abi.ABI, typeCache map[string]reflect.Type) reflect.Type {
func getReflectType(abiTypeName string, inputABI abi.ABI, typeCache map[string]reflect.Type) (reflect.Type, error) {
switch abiTypeName {
case "string":
return reflect.TypeOf("")
return reflect.TypeOf(""), nil
case "uint8":
return reflect.TypeOf(uint8(0))
return reflect.TypeOf(uint8(0)), nil
case "uint16":
return reflect.TypeOf(uint16(0))
return reflect.TypeOf(uint16(0)), nil
case "uint32":
return reflect.TypeOf(uint32(0))
return reflect.TypeOf(uint32(0)), nil
case "uint64":
return reflect.TypeOf(uint64(0))
return reflect.TypeOf(uint64(0)), nil
case "int8":
return reflect.TypeOf(int8(0))
return reflect.TypeOf(int8(0)), nil
case "int16":
return reflect.TypeOf(int16(0))
return reflect.TypeOf(int16(0)), nil
case "int32":
return reflect.TypeOf(int32(0))
return reflect.TypeOf(int32(0)), nil
case "int64":
return reflect.TypeOf(int64(0))
return reflect.TypeOf(int64(0)), nil
case "Address":
return reflect.TypeOf(codec.Address{})
return reflect.TypeOf(codec.Address{}), nil
default:
// golang slices
if strings.HasPrefix(abiTypeName, "[]") {
elemType := getReflectType(strings.TrimPrefix(abiTypeName, "[]"), inputAbi, typeCache)
return reflect.SliceOf(elemType)
} else if strings.HasPrefix(abiTypeName, "[") {
// Handle fixed-size arrays

sizeStr := strings.Split(abiTypeName, "]")[0]
sizeStr = strings.TrimPrefix(sizeStr, "[")
elemType, err := getReflectType(strings.TrimPrefix(abiTypeName, "[]"), inputABI, typeCache)
if err != nil {
return nil, err
}
return reflect.SliceOf(elemType), nil
}

// golang arrays
match := fixedSizeArrayRegex.FindStringSubmatch(abiTypeName) // ^\[(\d+)\](.+)$
if match != nil {
sizeStr := match[1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This variable is unnecessary as it's only used on the next line and it doesn't add any additional context beyond what size itself does.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's exclusively for readability

size, err := strconv.Atoi(sizeStr)
if err != nil {
return reflect.TypeOf((*interface{})(nil)).Elem()
return nil, fmt.Errorf("failed to convert size to int: %w", err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using %w makes the error part of your public API, which I don't think is the desired behaviour here. %v is more appropriate IMO.

}
elemType := getReflectType(strings.TrimPrefix(abiTypeName, "["+sizeStr+"]"), inputAbi, typeCache)
return reflect.ArrayOf(size, elemType)
elemType, err := getReflectType(match[2], inputABI, typeCache)
if err != nil {
return nil, err
}
return reflect.ArrayOf(size, elemType), nil
}
// For custom types, recursively construct the struct type

// Check if type already in cache
if cachedType, ok := typeCache[abiTypeName]; ok {
return cachedType
// For custom types, recursively construct the struct type
cachedType, ok := typeCache[abiTypeName]
if ok {
return cachedType, nil
}

// Find the type in the ABI
abiType := findABIType(inputAbi, abiTypeName)
if abiType == nil {
// If not found, fallback to interface{}
return reflect.TypeOf((*interface{})(nil)).Elem()
abiType, ok := findABIType(inputABI, abiTypeName)
if !ok {
return nil, fmt.Errorf("type %s not found in ABI", abiTypeName)
}

// Build fields
// It is a struct, as we don't support anything else as custom types
fields := make([]reflect.StructField, len(abiType.Fields))
for i, field := range abiType.Fields {
fieldType := getReflectType(field.Type, inputAbi, typeCache)
fieldType, err := getReflectType(field.Type, inputABI, typeCache)
if err != nil {
return nil, err
}
fields[i] = reflect.StructField{
Name: cases.Title(language.English).String(field.Name),
Type: fieldType,
Tag: reflect.StructTag(fmt.Sprintf(`serialize:"true" json:"%s"`, field.Name)),
}
}
// Create struct type
structType := reflect.StructOf(fields)

// Cache the type
structType := reflect.StructOf(fields)
typeCache[abiTypeName] = structType

return structType
return structType, nil
}
}

// Helper function to find ABI type
func findABIType(inputAbi abi.ABI, typeName string) *abi.Type {
for i := range inputAbi.Types {
if inputAbi.Types[i].Name == typeName {
return &inputAbi.Types[i]
func findABIType(inputABI abi.ABI, typeName string) (abi.Type, bool) {
for _, typ := range inputABI.Types {
if typ.Name == typeName {
return typ, true
}
}
return nil
return abi.Type{}, false
}
32 changes: 27 additions & 5 deletions abi/dynamic/reflect_marshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
func TestDynamicMarshal(t *testing.T) {
require := require.New(t)

// Load the ABI
abiJSON := mustReadFile(t, "../testdata/abi.json")
var abi abi.ABI

err := json.Unmarshal(abiJSON, &abi)
require.NoError(err)

Expand Down Expand Up @@ -48,17 +48,15 @@
// Read the JSON data
jsonData := mustReadFile(t, "../testdata/"+tc.name+".json")

// Use DynamicMarshal to marshal the data
objectBytes, err := DynamicMarshal(abi, tc.typeName, string(jsonData))
objectBytes, err := Marshal(abi, tc.typeName, string(jsonData))
require.NoError(err)

// Compare with expected hex
expectedHex := string(mustReadFile(t, "../testdata/"+tc.name+".hex"))
expectedHex = strings.TrimSpace(expectedHex)
require.Equal(expectedHex, hex.EncodeToString(objectBytes))

// Use DynamicUnmarshal to unmarshal the data
unmarshaledJSON, err := DynamicUnmarshal(abi, tc.typeName, objectBytes)
unmarshaledJSON, err := Unmarshal(abi, tc.typeName, objectBytes)
require.NoError(err)

// Compare with expected JSON
Expand All @@ -67,6 +65,30 @@
}
}

func TestDynamicMarshalErrors(t *testing.T) {
require := require.New(t)

abiJSON := mustReadFile(t, "../testdata/abi.json")
var abi abi.ABI

err := json.Unmarshal(abiJSON, &abi)
require.NoError(err)

t.Run("malformed JSON", func(t *testing.T) {

Check failure on line 77 in abi/dynamic/reflect_marshal_test.go

View workflow job for this annotation

GitHub Actions / hypersdk-lint

unused-parameter: parameter 't' seems to be unused, consider removing or renaming it as _ (revive)
malformedJSON := `{"uint8": 42, "uint16": 1000, "uint32": 100000, "uint64": 10000000000, "int8": -42, "int16": -1000, "int32": -100000, "int64": -10000000000,`
_, err := Marshal(abi, "MockObjectAllNumbers", malformedJSON)
require.Error(err)

Check failure on line 80 in abi/dynamic/reflect_marshal_test.go

View workflow job for this annotation

GitHub Actions / hypersdk-lint

use of `require.Error` forbidden because "ErrorIs should be used instead" (forbidigo)
require.Contains(err.Error(), "unexpected end of JSON input")
})

t.Run("wrong struct name", func(t *testing.T) {
jsonData := mustReadFile(t, "../testdata/numbers.json")
_, err := Marshal(abi, "NonExistentObject", string(jsonData))
require.Error(err)

Check failure on line 87 in abi/dynamic/reflect_marshal_test.go

View workflow job for this annotation

GitHub Actions / hypersdk-lint

use of `require.Error` forbidden because "ErrorIs should be used instead" (forbidigo)
require.Contains(err.Error(), "type NonExistentObject not found in ABI")
})
}

func mustReadFile(t *testing.T, path string) []byte {
t.Helper()

Expand Down
8 changes: 4 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ require (
go.uber.org/zap v1.26.0
golang.org/x/crypto v0.22.0
golang.org/x/exp v0.0.0-20231127185646-65229373498e
golang.org/x/sync v0.6.0
golang.org/x/sync v0.7.0
golang.org/x/text v0.14.0
google.golang.org/grpc v1.62.0
google.golang.org/protobuf v1.34.2
Expand Down Expand Up @@ -141,9 +141,9 @@ require (
go.opentelemetry.io/proto/otlp v1.0.0 // indirect
go.uber.org/mock v0.4.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/net v0.23.0 // indirect
golang.org/x/sys v0.18.0 // indirect
golang.org/x/term v0.18.0 // indirect
golang.org/x/net v0.24.0 // indirect
golang.org/x/sys v0.19.0 // indirect
golang.org/x/term v0.19.0 // indirect
golang.org/x/time v0.3.0 // indirect
golang.org/x/tools v0.17.0 // indirect
gonum.org/v1/gonum v0.11.0 // indirect
Expand Down
Loading