Skip to content

Commit

Permalink
feat: rapidproto generator support for gogo unmarshalling (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
kocubinski authored Feb 21, 2023
1 parent e7b0579 commit 0d47e93
Show file tree
Hide file tree
Showing 12 changed files with 108 additions and 31 deletions.
125 changes: 101 additions & 24 deletions rapidproto/rapidproto.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package rapidproto

import (
"fmt"
"math"

cosmos_proto "github.com/cosmos/cosmos-proto"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
Expand All @@ -15,20 +17,60 @@ func MessageGenerator[T proto.Message](x T, options GeneratorOptions) *rapid.Gen
return rapid.Custom(func(t *rapid.T) T {
msg := msgType.New()

options.setFields(t, msg, 0)
options.setFields(t, nil, msg, 0)

return msg.Interface().(T)
})
}

// FieldMapper is a function that can be used to override the default behavior of the generator for a specific field.
// The first argument is the rapid.T, the second is the field descriptor, and the third is the field name.
// If the function returns nil, the default behavior will be used.
type FieldMapper func(*rapid.T, protoreflect.FieldDescriptor, string) (protoreflect.Value, bool)

type GeneratorOptions struct {
AnyTypeURLs []string
Resolver protoregistry.MessageTypeResolver
AnyTypeURLs []string
InterfaceHints map[string]string
Resolver protoregistry.MessageTypeResolver

// NoEmptyLists will cause the generator to not generate empty lists
// Recall that an empty list will marshal (and unmarshal) to null. Some encodings may treat these states
// differently. For example, in JSON, an empty list is encoded as [], while null is encoded as null.
NoEmptyLists bool

// DisallowNilMessages will cause the generator to not generate nil messages to protoreflect.MessageKind fields
DisallowNilMessages bool

// FieldMaps is a list of FieldMapper functions that can be used to override the default behavior of the generator
// for a specific field.
FieldMaps []FieldMapper
}

const depthLimit = 10

func (opts GeneratorOptions) setFields(t *rapid.T, msg protoreflect.Message, depth int) bool {
func (opts GeneratorOptions) WithAnyTypes(anyTypes ...proto.Message) GeneratorOptions {
for _, a := range anyTypes {
opts.AnyTypeURLs = append(opts.AnyTypeURLs, fmt.Sprintf("/%s", a.ProtoReflect().Descriptor().FullName()))
}
return opts
}

func (opts GeneratorOptions) WithDisallowNil() GeneratorOptions {
o := &opts
o.DisallowNilMessages = true
return *o
}

func (opts GeneratorOptions) WithInterfaceHint(i string, impl proto.Message) GeneratorOptions {
if opts.InterfaceHints == nil {
opts.InterfaceHints = make(map[string]string)
}
opts.InterfaceHints[i] = string(impl.ProtoReflect().Descriptor().FullName())
return opts
}

func (opts GeneratorOptions) setFields(
t *rapid.T, field protoreflect.FieldDescriptor, msg protoreflect.Message, depth int) bool {
// to avoid stack overflow we limit the depth of nested messages
if depth > depthLimit {
return false
Expand All @@ -44,20 +86,23 @@ func (opts GeneratorOptions) setFields(t *rapid.T, msg protoreflect.Message, dep
opts.genDuration(t, msg)
return true
case anyFullName:
return opts.genAny(t, msg, depth)
opts.genAny(t, field, msg, depth)
return true
case fieldMaskFullName:
opts.genFieldMask(t, msg)
return true
default:
fields := descriptor.Fields()
n := fields.Len()
for i := 0; i < n; i++ {
field := fields.Get(i)
if !rapid.Bool().Draw(t, fmt.Sprintf("gen-%s", field.Name())) {
continue
f := fields.Get(i)
if !rapid.Bool().Draw(t, fmt.Sprintf("gen-%s", f.Name())) {
if (f.Kind() == protoreflect.MessageKind) && !opts.DisallowNilMessages {
continue
}
}

opts.setFieldValue(t, msg, field, depth)
opts.setFieldValue(t, msg, f, depth)
}
return true
}
Expand All @@ -77,10 +122,14 @@ func (opts GeneratorOptions) setFieldValue(t *rapid.T, msg protoreflect.Message,
switch {
case field.IsList():
list := msg.Mutable(field).List()
n := rapid.IntRange(0, 10).Draw(t, fmt.Sprintf("%sN", name))
min := 0
if opts.NoEmptyLists {
min = 1
}
n := rapid.IntRange(min, 10).Draw(t, fmt.Sprintf("%sN", name))
for i := 0; i < n; i++ {
if kind == protoreflect.MessageKind || kind == protoreflect.GroupKind {
if !opts.setFields(t, list.AppendMutable().Message(), depth+1) {
if !opts.setFields(t, field, list.AppendMutable().Message(), depth+1) {
list.Truncate(i)
}
} else {
Expand All @@ -96,26 +145,39 @@ func (opts GeneratorOptions) setFieldValue(t *rapid.T, msg protoreflect.Message,
valueKind := valueField.Kind()
key := opts.genScalarFieldValue(t, keyField, fmt.Sprintf("%s%d-key", name, i))
if valueKind == protoreflect.MessageKind || valueKind == protoreflect.GroupKind {
if !opts.setFields(t, m.Mutable(key.MapKey()).Message(), depth+1) {
if !opts.setFields(t, field, m.Mutable(key.MapKey()).Message(), depth+1) {
m.Clear(key.MapKey())
}
} else {
value := opts.genScalarFieldValue(t, valueField, fmt.Sprintf("%s%d-key", name, i))
m.Set(key.MapKey(), value)
}
}
default:
if kind == protoreflect.MessageKind || kind == protoreflect.GroupKind {
if !opts.setFields(t, msg.Mutable(field).Message(), depth+1) {
case kind == protoreflect.MessageKind:
mutableField := msg.Mutable(field)
if mutableField.Message().Descriptor().FullName() == anyFullName {
if !opts.genAny(t, field, mutableField.Message(), depth+1) {
msg.Clear(field)
}
} else {
msg.Set(field, opts.genScalarFieldValue(t, field, name))
} else if !opts.setFields(t, field, mutableField.Message(), depth+1) {
msg.Clear(field)
}
case kind == protoreflect.GroupKind:
if !opts.setFields(t, field, msg.Mutable(field).Message(), depth+1) {
msg.Clear(field)
}
default:
msg.Set(field, opts.genScalarFieldValue(t, field, name))
}
}

func (opts GeneratorOptions) genScalarFieldValue(t *rapid.T, field protoreflect.FieldDescriptor, name string) protoreflect.Value {
for _, fm := range opts.FieldMaps {
if v, ok := fm(t, field, name); ok {
return v
}
}

switch field.Kind() {
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
return protoreflect.ValueOfInt32(rapid.Int32().Draw(t, name))
Expand Down Expand Up @@ -146,8 +208,11 @@ func (opts GeneratorOptions) genScalarFieldValue(t *rapid.T, field protoreflect.
}

const (
secondsName = "seconds"
nanosName = "nanos"
// MaxDurationSeconds the maximum number of seconds (when expressed as nanoseconds) which can fit in an int64.
// gogoproto encodes google.protobuf.Duration as a time.Duration, which is 64-bit signed integer.
MaxDurationSeconds = int64(math.MaxInt64/int(1e9)) - 1
secondsName = "seconds"
nanosName = "nanos"
)

func (opts GeneratorOptions) genTimestamp(t *rapid.T, msg protoreflect.Message) {
Expand All @@ -157,7 +222,7 @@ func (opts GeneratorOptions) genTimestamp(t *rapid.T, msg protoreflect.Message)
}

func (opts GeneratorOptions) genDuration(t *rapid.T, msg protoreflect.Message) {
seconds := rapid.Int64Range(0, 315576000000).Draw(t, "seconds")
seconds := rapid.Int64Range(0, int64(MaxDurationSeconds)).Draw(t, "seconds")
nanos := rapid.Int32Range(0, 999999999).Draw(t, "nanos")
setSecondsNanosFields(t, msg, seconds, nanos)
}
Expand All @@ -179,23 +244,35 @@ const (
valueName = "value"
)

func (opts GeneratorOptions) genAny(t *rapid.T, msg protoreflect.Message, depth int) bool {
func (opts GeneratorOptions) genAny(
t *rapid.T, field protoreflect.FieldDescriptor, msg protoreflect.Message, depth int) bool {
if len(opts.AnyTypeURLs) == 0 {
return false
}

fields := msg.Descriptor().Fields()
var typeURL string
fopts := field.Options()
if proto.HasExtension(fopts, cosmos_proto.E_AcceptsInterface) {
ai := proto.GetExtension(fopts, cosmos_proto.E_AcceptsInterface).(string)
if impl, found := opts.InterfaceHints[ai]; found {
typeURL = fmt.Sprintf("/%s", impl)
} else {
panic(fmt.Sprintf("no implementation found for interface %s", ai))
}
} else {
typeURL = rapid.SampledFrom(opts.AnyTypeURLs).Draw(t, "type_url")
}

typeURL := rapid.SampledFrom(opts.AnyTypeURLs).Draw(t, "type_url")
typ, err := opts.Resolver.FindMessageByURL(typeURL)
assert.NilError(t, err)
fields := msg.Descriptor().Fields()

typeURLField := fields.ByName(typeURLName)
assert.Assert(t, typeURLField != nil)
msg.Set(typeURLField, protoreflect.ValueOfString(typeURL))

valueMsg := typ.New()
opts.setFields(t, valueMsg, depth+1)
opts.setFields(t, nil, valueMsg, depth+1)
valueBz, err := proto.Marshal(valueMsg.Interface())
assert.NilError(t, err)

Expand Down
4 changes: 2 additions & 2 deletions rapidproto/rapidproto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ import (
"fmt"
"testing"

"github.com/cosmos/cosmos-proto/rapidproto"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"gotest.tools/v3/assert"
"gotest.tools/v3/golden"
"pgregory.net/rapid"

"github.com/cosmos/cosmos-proto/rapidproto"
"github.com/cosmos/cosmos-proto/testpb"
)

Expand All @@ -19,7 +19,7 @@ import (
// to generally look good.
func TestRegression(t *testing.T) {
gen := rapidproto.MessageGenerator(&testpb.A{}, rapidproto.GeneratorOptions{})
for i := 0; i < 5; i++ {
for i := 1000; i < 1005; i++ {
testRegressionSeed(t, i, gen)
}
}
Expand Down
1 change: 0 additions & 1 deletion rapidproto/testdata/seed0.json

This file was deleted.

1 change: 0 additions & 1 deletion rapidproto/testdata/seed1.json

This file was deleted.

1 change: 1 addition & 0 deletions rapidproto/testdata/seed1000.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"someBoolean":true, "INT32":112, "SINT32":20861, "UINT32":39, "INT64":"1916", "SING64":"-208", "UINT64":"95", "SFIXED32":-4267293, "FIXED32":56, "FLOAT":-0.061933517, "SFIXED64":"-128961679122", "FIXED64":"56609281", "DOUBLE":-4.2137932902833315e+218, "STRING":"a", "BYTES":"EQzPLBA1mwUBCQ==", "MAP":{"ž󠀶×a)?c":{"x":"Ʉ⃞➋-"}}, "ONEOFSTRING":"Aা𑅅҉^*[ॎ{؂Ⱥ\n@^<n", "type":"܏󠀫?“@\u0003~—!+‮#"}
1 change: 1 addition & 0 deletions rapidproto/testdata/seed1001.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"someBoolean":true, "INT32":-3, "SINT32":-15161, "UINT32":22975, "INT64":"-10183", "SING64":"1", "UINT64":"118", "SFIXED32":283, "FIXED32":4294967295, "FLOAT":55.4375, "SFIXED64":"1", "FIXED64":"122", "DOUBLE":2.255889892109379e-30, "STRING":"\\\u0000", "BYTES":"Dqs=", "MESSAGE":{"x":"ꦽ:a Ⱥ"}, "MAP":{"\u0001/?-?aº":{"x":"٤"}, "!<ाAa\t":{"x":""}, "$₈?a㱰ᛯ\u0000*ઃ":{"x":"􊬊݉?"}, "-\u0002𑴲፭":{"x":" ܏A!Ⱥ𐓃ꧦ\u001bƻি\țH꣕~ؔ𑱞"}}, "LIST":[{"x":"#aª*@𖽠"}, {"x":"A!\u001a??'[כ"}, {"x":"+𝍧~!\n~"}, {"x":"!&= ±bAᛮ؎°׆?Ֆ"}, {"x":"?*@ᾝ "}, {"x":"%𫑙%`푸𝕘𐡹҉܃"}, {"x":"7A𓐸₥?%"}, {"x":"\u0007a/ॉa"}, {"x":"$^A‪Ⱥ$"}, {"x":" B҈\u001b"}], "ONEOFSTRING":"&ϔ/`", "LISTENUM":["Two", "One", "Two", "Two", "Two"], "type":"҉𓐴A"}
1 change: 1 addition & 0 deletions rapidproto/testdata/seed1002.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"enum":"Two", "INT32":-2718, "SINT32":-60035, "UINT32":2601799, "INT64":"-59", "SING64":"-2", "UINT64":"80", "FIXED32":1, "FLOAT":-0.08959961, "SFIXED64":"-2047", "FIXED64":"175", "DOUBLE":3.14684672949901e-85, "BYTES":"MB8=", "MESSAGE":{}, "ONEOFSTRING":"ⅾaՅ", "imported":{}}
1 change: 1 addition & 0 deletions rapidproto/testdata/seed1003.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"someBoolean":true, "INT32":-59, "SINT32":3, "UINT32":390, "INT64":"-2", "SING64":"-8", "SFIXED32":1171755373, "FIXED32":747, "FLOAT":-0.0033711107, "SFIXED64":"1", "FIXED64":"75", "DOUBLE":3.5, "LIST":[{"x":"{𑤳\""}, {}, {"x":"⃤¹⃢௹~‮±⪏"}, {"x":"*"}, {"x":"؜~"}, {"x":"~"}, {}, {"x":"AȺȺ᮵?֎៰"}], "ONEOFSTRING":"", "LISTENUM":["One", "Two", "Two"], "imported":{}, "type":"ʱ\u000b󠁇"}
1 change: 1 addition & 0 deletions rapidproto/testdata/seed1004.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"INT32":-201, "SINT32":1, "UINT32":5, "INT64":"3", "SING64":"12", "UINT64":"2860198", "SFIXED32":3, "FIXED32":1597511853, "FLOAT":0.0002543144, "SFIXED64":"-33339063266", "FIXED64":"205355", "DOUBLE":-4.876529669536761e+251, "STRING":"ˌᾋ˧#₨뗭ॱ⨲0_-‮⓹+A~\u000b!ᤳ𐄘%w₥\u0000", "BYTES":"BP+NAQ==", "ONEOFSTRING":"ả₿⸟؀1", "LISTENUM":["Two", "One", "One", "Two", "Two", "One", "One"], "imported":{}, "type":"@@\\"}
1 change: 0 additions & 1 deletion rapidproto/testdata/seed2.json

This file was deleted.

1 change: 0 additions & 1 deletion rapidproto/testdata/seed3.json

This file was deleted.

1 change: 0 additions & 1 deletion rapidproto/testdata/seed4.json

This file was deleted.

0 comments on commit 0d47e93

Please sign in to comment.