Skip to content

Commit

Permalink
refactor(codec/unknownproto): use official descriptorpb types, do not…
Browse files Browse the repository at this point in the history
… rely on gogo code gen types. (#18541)

Co-authored-by: unknown unknown <unknown@unknown>
  • Loading branch information
testinginprod and unknown unknown authored Nov 29, 2023
1 parent 7bdbc46 commit ae7288b
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 46 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ Ref: https://keepachangelog.com/en/1.0.0/
* [#17733](https://github.com/cosmos/cosmos-sdk/pull/17733) Ensure `buf export` exports all proto dependencies
* (version) [#18063](https://github.com/cosmos/cosmos-sdk/pull/18063) Include additional information in the Info struct. This change enhances the Info struct by adding support for additional information through the ExtraInfo field
* (crypto | x/auth) [#14372](https://github.com/cosmos/cosmos-sdk/pull/18194) Key checks on signatures antehandle.
* (codec/unknownproto) [#18541](https://github.com/cosmos/cosmos-sdk/pull/18541) Remove the use of "protoc-gen-gogo/descriptor" in favour of using the official protobuf descriptorpb types inside unknownproto.
* (staking) [#18506](https://github.com/cosmos/cosmos-sdk/pull/18506) Detect the length of the ed25519 pubkey in CreateValidator to prevent panic.

### Bug Fixes
Expand Down
158 changes: 112 additions & 46 deletions codec/unknownproto/unknown_fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ import (

"github.com/cosmos/gogoproto/jsonpb"
"github.com/cosmos/gogoproto/proto"
"github.com/cosmos/gogoproto/protoc-gen-gogo/descriptor"
"google.golang.org/protobuf/encoding/protowire"
protov2 "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/descriptorpb"

"github.com/cosmos/cosmos-sdk/codec/types"
)
Expand Down Expand Up @@ -68,7 +69,7 @@ func RejectUnknownFields(bz []byte, msg proto.Message, allowUnknownNonCriticals
Type: reflect.ValueOf(msg).Type().String(),
TagNum: tagNum,
GotWireType: wireType,
WantWireType: protowire.Type(fieldDescProto.WireType()),
WantWireType: toProtowireType(fieldDescProto.GetType()),
}
}

Expand Down Expand Up @@ -101,14 +102,14 @@ func RejectUnknownFields(bz []byte, msg proto.Message, allowUnknownNonCriticals
bz = bz[n:]

// An unknown but non-critical field or just a scalar type (aka *INT and BYTES like).
if fieldDescProto == nil || fieldDescProto.IsScalar() {
if fieldDescProto == nil || isScalar(fieldDescProto) {
continue
}

protoMessageName := fieldDescProto.GetTypeName()
if protoMessageName == "" {
switch typ := fieldDescProto.GetType(); typ {
case descriptor.FieldDescriptorProto_TYPE_STRING, descriptor.FieldDescriptorProto_TYPE_BYTES:
case descriptorpb.FieldDescriptorProto_TYPE_STRING, descriptorpb.FieldDescriptorProto_TYPE_BYTES:
// At this point only TYPE_STRING is expected to be unregistered, since FieldDescriptorProto.IsScalar() returns false for
// TYPE_BYTES and TYPE_STRING as per
// https://github.com/cosmos/gogoproto/blob/5628607bb4c51c3157aacc3a50f0ab707582b805/protoc-gen-gogo/descriptor/descriptor.go#L95-L118
Expand Down Expand Up @@ -199,68 +200,68 @@ func protoMessageForTypeName(protoMessageName string) (proto.Message, error) {
// checks is a mapping of protowire.Type to supported descriptorpb.FieldDescriptorProto_Type.
// it is implemented this way so as to have constant time lookups and avoid the overhead
// from O(n) walking of switch. The change to using this mapping boosts throughput by about 200%.
var checks = [...]map[descriptor.FieldDescriptorProto_Type]bool{
var checks = [...]map[descriptorpb.FieldDescriptorProto_Type]bool{
// "0 Varint: int32, int64, uint32, uint64, sint32, sint64, bool, enum"
0: {
descriptor.FieldDescriptorProto_TYPE_INT32: true,
descriptor.FieldDescriptorProto_TYPE_INT64: true,
descriptor.FieldDescriptorProto_TYPE_UINT32: true,
descriptor.FieldDescriptorProto_TYPE_UINT64: true,
descriptor.FieldDescriptorProto_TYPE_SINT32: true,
descriptor.FieldDescriptorProto_TYPE_SINT64: true,
descriptor.FieldDescriptorProto_TYPE_BOOL: true,
descriptor.FieldDescriptorProto_TYPE_ENUM: true,
descriptorpb.FieldDescriptorProto_TYPE_INT32: true,
descriptorpb.FieldDescriptorProto_TYPE_INT64: true,
descriptorpb.FieldDescriptorProto_TYPE_UINT32: true,
descriptorpb.FieldDescriptorProto_TYPE_UINT64: true,
descriptorpb.FieldDescriptorProto_TYPE_SINT32: true,
descriptorpb.FieldDescriptorProto_TYPE_SINT64: true,
descriptorpb.FieldDescriptorProto_TYPE_BOOL: true,
descriptorpb.FieldDescriptorProto_TYPE_ENUM: true,
},

// "1 64-bit: fixed64, sfixed64, double"
1: {
descriptor.FieldDescriptorProto_TYPE_FIXED64: true,
descriptor.FieldDescriptorProto_TYPE_SFIXED64: true,
descriptor.FieldDescriptorProto_TYPE_DOUBLE: true,
descriptorpb.FieldDescriptorProto_TYPE_FIXED64: true,
descriptorpb.FieldDescriptorProto_TYPE_SFIXED64: true,
descriptorpb.FieldDescriptorProto_TYPE_DOUBLE: true,
},

// "2 Length-delimited: string, bytes, embedded messages, packed repeated fields"
2: {
descriptor.FieldDescriptorProto_TYPE_STRING: true,
descriptor.FieldDescriptorProto_TYPE_BYTES: true,
descriptor.FieldDescriptorProto_TYPE_MESSAGE: true,
descriptorpb.FieldDescriptorProto_TYPE_STRING: true,
descriptorpb.FieldDescriptorProto_TYPE_BYTES: true,
descriptorpb.FieldDescriptorProto_TYPE_MESSAGE: true,
// The following types can be packed repeated.
// ref: "Only repeated fields of primitive numeric types (types which use the varint, 32-bit, or 64-bit wire types) can be declared "packed"."
// ref: https://developers.google.com/protocol-buffers/docs/encoding#packed
descriptor.FieldDescriptorProto_TYPE_INT32: true,
descriptor.FieldDescriptorProto_TYPE_INT64: true,
descriptor.FieldDescriptorProto_TYPE_UINT32: true,
descriptor.FieldDescriptorProto_TYPE_UINT64: true,
descriptor.FieldDescriptorProto_TYPE_SINT32: true,
descriptor.FieldDescriptorProto_TYPE_SINT64: true,
descriptor.FieldDescriptorProto_TYPE_BOOL: true,
descriptor.FieldDescriptorProto_TYPE_ENUM: true,
descriptor.FieldDescriptorProto_TYPE_FIXED64: true,
descriptor.FieldDescriptorProto_TYPE_SFIXED64: true,
descriptor.FieldDescriptorProto_TYPE_DOUBLE: true,
descriptorpb.FieldDescriptorProto_TYPE_INT32: true,
descriptorpb.FieldDescriptorProto_TYPE_INT64: true,
descriptorpb.FieldDescriptorProto_TYPE_UINT32: true,
descriptorpb.FieldDescriptorProto_TYPE_UINT64: true,
descriptorpb.FieldDescriptorProto_TYPE_SINT32: true,
descriptorpb.FieldDescriptorProto_TYPE_SINT64: true,
descriptorpb.FieldDescriptorProto_TYPE_BOOL: true,
descriptorpb.FieldDescriptorProto_TYPE_ENUM: true,
descriptorpb.FieldDescriptorProto_TYPE_FIXED64: true,
descriptorpb.FieldDescriptorProto_TYPE_SFIXED64: true,
descriptorpb.FieldDescriptorProto_TYPE_DOUBLE: true,
},

// "3 Start group: groups (deprecated)"
3: {
descriptor.FieldDescriptorProto_TYPE_GROUP: true,
descriptorpb.FieldDescriptorProto_TYPE_GROUP: true,
},

// "4 End group: groups (deprecated)"
4: {
descriptor.FieldDescriptorProto_TYPE_GROUP: true,
descriptorpb.FieldDescriptorProto_TYPE_GROUP: true,
},

// "5 32-bit: fixed32, sfixed32, float"
5: {
descriptor.FieldDescriptorProto_TYPE_FIXED32: true,
descriptor.FieldDescriptorProto_TYPE_SFIXED32: true,
descriptor.FieldDescriptorProto_TYPE_FLOAT: true,
descriptorpb.FieldDescriptorProto_TYPE_FIXED32: true,
descriptorpb.FieldDescriptorProto_TYPE_SFIXED32: true,
descriptorpb.FieldDescriptorProto_TYPE_FLOAT: true,
},
}

// canEncodeType returns true if the wireType is suitable for encoding the descriptor type.
// See https://developers.google.com/protocol-buffers/docs/encoding#structure.
func canEncodeType(wireType protowire.Type, descType descriptor.FieldDescriptorProto_Type) bool {
func canEncodeType(wireType protowire.Type, descType descriptorpb.FieldDescriptorProto_Type) bool {
if iwt := int(wireType); iwt < 0 || iwt >= len(checks) {
return false
}
Expand Down Expand Up @@ -330,21 +331,21 @@ func (twt *errUnknownField) Error() string {
var _ error = (*errUnknownField)(nil)

var (
protoFileToDesc = make(map[string]*descriptor.FileDescriptorProto)
protoFileToDesc = make(map[string]*descriptorpb.FileDescriptorProto)
protoFileToDescMu sync.RWMutex
)

func unnestDesc(mdescs []*descriptor.DescriptorProto, indices []int) *descriptor.DescriptorProto {
func unnestDesc(mdescs []*descriptorpb.DescriptorProto, indices []int) *descriptorpb.DescriptorProto {
mdesc := mdescs[indices[0]]
for _, index := range indices[1:] {
mdesc = mdesc.NestedType[index]
}
return mdesc
}

// Invoking descriptor.ForMessage(proto.Message.(Descriptor).Descriptor()) is incredibly slow
// Invoking gogoproto's descriptor.ForMessage(proto.Message.(Descriptor).Descriptor()) is incredibly slow
// for every single message, thus the need for a hand-rolled custom version that's performant and cacheable.
func extractFileDescMessageDesc(desc descriptorIface) (*descriptor.FileDescriptorProto, *descriptor.DescriptorProto, error) {
func extractFileDescMessageDesc(desc descriptorIface) (*descriptorpb.FileDescriptorProto, *descriptorpb.DescriptorProto, error) {
gzippedPb, indices := desc.Descriptor()

protoFileToDescMu.RLock()
Expand All @@ -365,8 +366,8 @@ func extractFileDescMessageDesc(desc descriptorIface) (*descriptor.FileDescripto
return nil, nil, err
}

fdesc := new(descriptor.FileDescriptorProto)
if err := proto.Unmarshal(protoBlob, fdesc); err != nil {
fdesc := new(descriptorpb.FileDescriptorProto)
if err := protov2.Unmarshal(protoBlob, fdesc); err != nil {
return nil, nil, err
}

Expand All @@ -380,8 +381,8 @@ func extractFileDescMessageDesc(desc descriptorIface) (*descriptor.FileDescripto
}

type descriptorMatch struct {
cache map[int32]*descriptor.FieldDescriptorProto
desc *descriptor.DescriptorProto
cache map[int32]*descriptorpb.FieldDescriptorProto
desc *descriptorpb.DescriptorProto
}

var (
Expand All @@ -390,7 +391,7 @@ var (
)

// getDescriptorInfo retrieves the mapping of field numbers to their respective field descriptors.
func getDescriptorInfo(desc descriptorIface, msg proto.Message) (map[int32]*descriptor.FieldDescriptorProto, *descriptor.DescriptorProto, error) {
func getDescriptorInfo(desc descriptorIface, msg proto.Message) (map[int32]*descriptorpb.FieldDescriptorProto, *descriptorpb.DescriptorProto, error) {
key := reflect.ValueOf(msg).Type()

descprotoCacheMu.RLock()
Expand All @@ -407,7 +408,7 @@ func getDescriptorInfo(desc descriptorIface, msg proto.Message) (map[int32]*desc
return nil, nil, err
}

tagNumToTypeIndex := make(map[int32]*descriptor.FieldDescriptorProto)
tagNumToTypeIndex := make(map[int32]*descriptorpb.FieldDescriptorProto)
for _, field := range md.Field {
tagNumToTypeIndex[field.GetNumber()] = field
}
Expand Down Expand Up @@ -441,3 +442,68 @@ func (d DefaultAnyResolver) Resolve(typeURL string) (proto.Message, error) {
}
return reflect.New(mt.Elem()).Interface().(proto.Message), nil
}

// toProtowireType converts a descriptorpb.FieldDescriptorProto_Type to the
// protowire.Type that a single (non-packed) value of that type is encoded with.
//
// Groups are reported as protowire.StartGroupType: they are framed by
// start/end group tags rather than a length prefix, and the checks table in
// this file accepts TYPE_GROUP only for wire types 3 and 4, never for
// length-delimited data.
//
// It panics when fieldType is not one of the field types handled below.
func toProtowireType(fieldType descriptorpb.FieldDescriptorProto_Type) protowire.Type {
	switch fieldType {
	// varint encoded
	case descriptorpb.FieldDescriptorProto_TYPE_INT64,
		descriptorpb.FieldDescriptorProto_TYPE_UINT64,
		descriptorpb.FieldDescriptorProto_TYPE_INT32,
		descriptorpb.FieldDescriptorProto_TYPE_UINT32,
		descriptorpb.FieldDescriptorProto_TYPE_BOOL,
		descriptorpb.FieldDescriptorProto_TYPE_ENUM,
		descriptorpb.FieldDescriptorProto_TYPE_SINT32,
		descriptorpb.FieldDescriptorProto_TYPE_SINT64:
		return protowire.VarintType

	// fixed64 encoded
	case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE,
		descriptorpb.FieldDescriptorProto_TYPE_FIXED64,
		descriptorpb.FieldDescriptorProto_TYPE_SFIXED64:
		return protowire.Fixed64Type

	// fixed32 encoded
	case descriptorpb.FieldDescriptorProto_TYPE_FLOAT,
		descriptorpb.FieldDescriptorProto_TYPE_FIXED32,
		descriptorpb.FieldDescriptorProto_TYPE_SFIXED32:
		return protowire.Fixed32Type

	// bytes (length-delimited) encoded
	case descriptorpb.FieldDescriptorProto_TYPE_STRING,
		descriptorpb.FieldDescriptorProto_TYPE_BYTES,
		descriptorpb.FieldDescriptorProto_TYPE_MESSAGE:
		return protowire.BytesType

	// group encoded: opened by a start-group tag, not a length prefix
	case descriptorpb.FieldDescriptorProto_TYPE_GROUP:
		return protowire.StartGroupType

	default:
		panic(fmt.Sprintf("unknown field type %s", fieldType))
	}
}

// isScalar reports whether the field's declared type is one of the numeric,
// bool or enum types. A field with no type set is not scalar, and neither are
// string, bytes, message or group fields.
// Copied from gogo/protobuf/protoc-gen-gogo
// https://github.com/gogo/protobuf/blob/b03c65ea87cdc3521ede29f62fe3ce239267c1bc/protoc-gen-gogo/descriptor/descriptor.go#L95
func isScalar(field *descriptorpb.FieldDescriptorProto) bool {
	typ := field.Type
	if typ == nil {
		return false
	}

	switch *typ {
	// varint-encoded types
	case descriptorpb.FieldDescriptorProto_TYPE_INT32,
		descriptorpb.FieldDescriptorProto_TYPE_INT64,
		descriptorpb.FieldDescriptorProto_TYPE_UINT32,
		descriptorpb.FieldDescriptorProto_TYPE_UINT64,
		descriptorpb.FieldDescriptorProto_TYPE_SINT32,
		descriptorpb.FieldDescriptorProto_TYPE_SINT64,
		descriptorpb.FieldDescriptorProto_TYPE_BOOL,
		descriptorpb.FieldDescriptorProto_TYPE_ENUM:
		return true

	// 64-bit fixed-width types
	case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE,
		descriptorpb.FieldDescriptorProto_TYPE_FIXED64,
		descriptorpb.FieldDescriptorProto_TYPE_SFIXED64:
		return true

	// 32-bit fixed-width types
	case descriptorpb.FieldDescriptorProto_TYPE_FLOAT,
		descriptorpb.FieldDescriptorProto_TYPE_FIXED32,
		descriptorpb.FieldDescriptorProto_TYPE_SFIXED32:
		return true
	}

	return false
}

0 comments on commit ae7288b

Please sign in to comment.