diff --git a/rapidproto/rapidproto.go b/rapidproto/rapidproto.go index 2922cbc..4915031 100644 --- a/rapidproto/rapidproto.go +++ b/rapidproto/rapidproto.go @@ -2,7 +2,9 @@ package rapidproto import ( "fmt" + "math" + cosmos_proto "github.com/cosmos/cosmos-proto" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" @@ -15,20 +17,60 @@ func MessageGenerator[T proto.Message](x T, options GeneratorOptions) *rapid.Gen return rapid.Custom(func(t *rapid.T) T { msg := msgType.New() - options.setFields(t, msg, 0) + options.setFields(t, nil, msg, 0) return msg.Interface().(T) }) } +// FieldMapper is a function that can be used to override the default behavior of the generator for a specific field. +// The first argument is the rapid.T, the second is the field descriptor, and the third is the field name. +// If the function returns nil, the default behavior will be used. +type FieldMapper func(*rapid.T, protoreflect.FieldDescriptor, string) (protoreflect.Value, bool) + type GeneratorOptions struct { - AnyTypeURLs []string - Resolver protoregistry.MessageTypeResolver + AnyTypeURLs []string + InterfaceHints map[string]string + Resolver protoregistry.MessageTypeResolver + + // NoEmptyLists will cause the generator to not generate empty lists + // Recall that an empty list will marshal (and unmarshal) to null. Some encodings may treat these states + // differently. For example, in JSON, an empty list is encoded as [], while null is encoded as null. + NoEmptyLists bool + + // DisallowNilMessages will cause the generator to not generate nil messages to protoreflect.MessageKind fields + DisallowNilMessages bool + + // FieldMaps is a list of FieldMapper functions that can be used to override the default behavior of the generator + // for a specific field. + FieldMaps []FieldMapper } const depthLimit = 10 -func (opts GeneratorOptions) setFields(t *rapid.T, msg protoreflect.Message, depth int) bool { +func (opts GeneratorOptions) WithAnyTypes(anyTypes ...proto.Message) GeneratorOptions { + for _, a := range anyTypes { + opts.AnyTypeURLs = append(opts.AnyTypeURLs, fmt.Sprintf("/%s", a.ProtoReflect().Descriptor().FullName())) + } + return opts +} + +func (opts GeneratorOptions) WithDisallowNil() GeneratorOptions { + o := &opts + o.DisallowNilMessages = true + return *o +} + +func (opts GeneratorOptions) WithInterfaceHint(i string, impl proto.Message) GeneratorOptions { + if opts.InterfaceHints == nil { + opts.InterfaceHints = make(map[string]string) + } + opts.InterfaceHints[i] = string(impl.ProtoReflect().Descriptor().FullName()) + return opts +} + +func (opts GeneratorOptions) setFields( + t *rapid.T, field protoreflect.FieldDescriptor, msg protoreflect.Message, depth int) bool { // to avoid stack overflow we limit the depth of nested messages if depth > depthLimit { return false @@ -44,7 +86,8 @@ func (opts GeneratorOptions) setFields(t *rapid.T, msg protoreflect.Message, dep opts.genDuration(t, msg) return true case anyFullName: - return opts.genAny(t, msg, depth) + opts.genAny(t, field, msg, depth) + return true case fieldMaskFullName: opts.genFieldMask(t, msg) return true @@ -52,12 +95,14 @@ func (opts GeneratorOptions) setFields(t *rapid.T, msg protoreflect.Message, dep fields := descriptor.Fields() n := fields.Len() for i := 0; i < n; i++ { - field := fields.Get(i) - if !rapid.Bool().Draw(t, fmt.Sprintf("gen-%s", field.Name())) { - continue + f := fields.Get(i) + if !rapid.Bool().Draw(t, fmt.Sprintf("gen-%s", f.Name())) { + if (f.Kind() == protoreflect.MessageKind) && !opts.DisallowNilMessages { + continue + } } - opts.setFieldValue(t, msg, field, depth) + opts.setFieldValue(t, msg, f, depth) } return true } @@ -77,10 +122,14 @@ func (opts GeneratorOptions) setFieldValue(t *rapid.T, msg protoreflect.Message, switch { case field.IsList(): list := msg.Mutable(field).List() - n := rapid.IntRange(0, 10).Draw(t, fmt.Sprintf("%sN", name)) + min := 0 + if opts.NoEmptyLists { + min = 1 + } + n := rapid.IntRange(min, 10).Draw(t, fmt.Sprintf("%sN", name)) for i := 0; i < n; i++ { if kind == protoreflect.MessageKind || kind == protoreflect.GroupKind { - if !opts.setFields(t, list.AppendMutable().Message(), depth+1) { + if !opts.setFields(t, field, list.AppendMutable().Message(), depth+1) { list.Truncate(i) } } else { @@ -96,7 +145,7 @@ func (opts GeneratorOptions) setFieldValue(t *rapid.T, msg protoreflect.Message, valueKind := valueField.Kind() key := opts.genScalarFieldValue(t, keyField, fmt.Sprintf("%s%d-key", name, i)) if valueKind == protoreflect.MessageKind || valueKind == protoreflect.GroupKind { - if !opts.setFields(t, m.Mutable(key.MapKey()).Message(), depth+1) { + if !opts.setFields(t, field, m.Mutable(key.MapKey()).Message(), depth+1) { m.Clear(key.MapKey()) } } else { @@ -104,18 +153,31 @@ func (opts GeneratorOptions) setFieldValue(t *rapid.T, msg protoreflect.Message, m.Set(key.MapKey(), value) } } - default: - if kind == protoreflect.MessageKind || kind == protoreflect.GroupKind { - if !opts.setFields(t, msg.Mutable(field).Message(), depth+1) { + case kind == protoreflect.MessageKind: + mutableField := msg.Mutable(field) + if mutableField.Message().Descriptor().FullName() == anyFullName { + if !opts.genAny(t, field, mutableField.Message(), depth+1) { msg.Clear(field) } - } else { - msg.Set(field, opts.genScalarFieldValue(t, field, name)) + } else if !opts.setFields(t, field, mutableField.Message(), depth+1) { + msg.Clear(field) } + case kind == protoreflect.GroupKind: + if !opts.setFields(t, field, msg.Mutable(field).Message(), depth+1) { + msg.Clear(field) + } + default: + msg.Set(field, opts.genScalarFieldValue(t, field, name)) } } func (opts GeneratorOptions) genScalarFieldValue(t *rapid.T, field protoreflect.FieldDescriptor, name string) protoreflect.Value { + for _, fm := range opts.FieldMaps { + if v, ok := fm(t, field, name); ok { + return v + } + } + switch field.Kind() { case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: return protoreflect.ValueOfInt32(rapid.Int32().Draw(t, name)) @@ -146,8 +208,11 @@ func (opts GeneratorOptions) genScalarFieldValue(t *rapid.T, field protoreflect. } const ( - secondsName = "seconds" - nanosName = "nanos" + // MaxDurationSeconds the maximum number of seconds (when expressed as nanoseconds) which can fit in an int64. + // gogoproto encodes google.protobuf.Duration as a time.Duration, which is 64-bit signed integer. + MaxDurationSeconds = int64(math.MaxInt64/int(1e9)) - 1 + secondsName = "seconds" + nanosName = "nanos" ) func (opts GeneratorOptions) genTimestamp(t *rapid.T, msg protoreflect.Message) { @@ -157,7 +222,7 @@ func (opts GeneratorOptions) genTimestamp(t *rapid.T, msg protoreflect.Message) } func (opts GeneratorOptions) genDuration(t *rapid.T, msg protoreflect.Message) { - seconds := rapid.Int64Range(0, 315576000000).Draw(t, "seconds") + seconds := rapid.Int64Range(0, int64(MaxDurationSeconds)).Draw(t, "seconds") nanos := rapid.Int32Range(0, 999999999).Draw(t, "nanos") setSecondsNanosFields(t, msg, seconds, nanos) } @@ -179,23 +244,35 @@ const ( valueName = "value" ) -func (opts GeneratorOptions) genAny(t *rapid.T, msg protoreflect.Message, depth int) bool { +func (opts GeneratorOptions) genAny( + t *rapid.T, field protoreflect.FieldDescriptor, msg protoreflect.Message, depth int) bool { if len(opts.AnyTypeURLs) == 0 { return false } - fields := msg.Descriptor().Fields() + var typeURL string + fopts := field.Options() + if proto.HasExtension(fopts, cosmos_proto.E_AcceptsInterface) { + ai := proto.GetExtension(fopts, cosmos_proto.E_AcceptsInterface).(string) + if impl, found := opts.InterfaceHints[ai]; found { + typeURL = fmt.Sprintf("/%s", impl) + } else { + panic(fmt.Sprintf("no implementation found for interface %s", ai)) + } + } else { + typeURL = rapid.SampledFrom(opts.AnyTypeURLs).Draw(t, "type_url") + } - typeURL := rapid.SampledFrom(opts.AnyTypeURLs).Draw(t, "type_url") typ, err := opts.Resolver.FindMessageByURL(typeURL) assert.NilError(t, err) + fields := msg.Descriptor().Fields() typeURLField := fields.ByName(typeURLName) assert.Assert(t, typeURLField != nil) msg.Set(typeURLField, protoreflect.ValueOfString(typeURL)) valueMsg := typ.New() - opts.setFields(t, valueMsg, depth+1) + opts.setFields(t, nil, valueMsg, depth+1) valueBz, err := proto.Marshal(valueMsg.Interface()) assert.NilError(t, err) diff --git a/rapidproto/rapidproto_test.go b/rapidproto/rapidproto_test.go index 7f76a9b..132876e 100644 --- a/rapidproto/rapidproto_test.go +++ b/rapidproto/rapidproto_test.go @@ -4,13 +4,13 @@ import ( "fmt" "testing" - "github.com/cosmos/cosmos-proto/rapidproto" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "gotest.tools/v3/assert" "gotest.tools/v3/golden" "pgregory.net/rapid" + "github.com/cosmos/cosmos-proto/rapidproto" "github.com/cosmos/cosmos-proto/testpb" ) @@ -19,7 +19,7 @@ import ( // to generally look good. func TestRegression(t *testing.T) { gen := rapidproto.MessageGenerator(&testpb.A{}, rapidproto.GeneratorOptions{}) - for i := 0; i < 5; i++ { + for i := 1000; i < 1005; i++ { testRegressionSeed(t, i, gen) } } diff --git a/rapidproto/testdata/seed0.json b/rapidproto/testdata/seed0.json deleted file mode 100644 index fc8ba93..0000000 --- a/rapidproto/testdata/seed0.json +++ /dev/null @@ -1 +0,0 @@ -{"enum":"Two", "someBoolean":true, "INT32":6, "SINT32":-53, "INT64":"-261", "SFIXED32":3, "FIXED32":65302, "FIXED64":"45044", "STRING":"󳲠~Âaႃ#", "MESSAGE":{"x":"ʰ="}, "MAP":{"":{"x":"௹"}, "%󠇯º$&.":{"x":"-"}, "=A":{}, "AA|𞀠":{"x":"a\u0000ๆ"}}, "LIST":[{}], "ONEOFSTRING":"", "imported":{}} \ No newline at end of file diff --git a/rapidproto/testdata/seed1.json b/rapidproto/testdata/seed1.json deleted file mode 100644 index 881e8d3..0000000 --- a/rapidproto/testdata/seed1.json +++ /dev/null @@ -1 +0,0 @@ -{"UINT32":177, "INT64":"-139958413", "SFIXED32":41418, "FIXED32":25381940, "FLOAT":-8.336453e+31, "SFIXED64":"-2503553836720", "DOUBLE":-0.03171187036377887, "STRING":"?˄~ע", "MESSAGE":{"x":"dDž#"}, "MAP":{"Ⱥa<":{"x":"+["}, "֑Ⱥ|@!`":{}}, "ONEOFSTRING":"\u0012\t?A", "imported":{}, "type":"A�=*ى~~‮Ⱥ*ᾈാȺAᶊ?"} \ No newline at end of file diff --git a/rapidproto/testdata/seed1000.json b/rapidproto/testdata/seed1000.json new file mode 100644 index 0000000..4294f9d --- /dev/null +++ b/rapidproto/testdata/seed1000.json @@ -0,0 +1 @@ +{"someBoolean":true, "INT32":112, "SINT32":20861, "UINT32":39, "INT64":"1916", "SING64":"-208", "UINT64":"95", "SFIXED32":-4267293, "FIXED32":56, "FLOAT":-0.061933517, "SFIXED64":"-128961679122", "FIXED64":"56609281", "DOUBLE":-4.2137932902833315e+218, "STRING":"a", "BYTES":"EQzPLBA1mwUBCQ==", "MAP":{"ž󠀶×a)?c":{"x":"Ʉ⃞➋-"}}, "ONEOFSTRING":"Aা𑅅҉^*[ॎ{؂Ⱥ\n@^