Skip to content

Commit

Permalink
proto: add logic to handle legacy message (protocolbuffers#496)
Browse files Browse the repository at this point in the history
Not all proto.Message implementations will be updated to be
using the most recent protoc-gen-go. Thus, they will lack an
XXX_DiscardUnknown method. Add logic to handle older protobufs.
  • Loading branch information
dsnet authored Jan 25, 2018
1 parent f9bf3fb commit 10c2d9d
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 15 deletions.
120 changes: 106 additions & 14 deletions proto/discard.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,10 @@ func DiscardUnknown(m Message) {
m.XXX_DiscardUnknown()
return
}
if m == nil {
return
}

// The Message interface really needs to provide some form of reflection
// API that can be used to implement the fallback if someone is using
// a custom protobuf message implementation.
//
// See https://github.com/golang/protobuf/issues/364
panic(fmt.Sprintf("cannot discard unknown fields on %T", m))
// TODO: Dynamically populate a InternalMessageInfo for legacy messages,
// but the master branch has no implementation for InternalMessageInfo,
// so it would be more work to replicate that approach.
discardLegacy(m)
}

// DiscardUnknown recursively discards all unknown fields.
Expand Down Expand Up @@ -172,14 +166,14 @@ func (di *discardInfo) computeDiscardInfo() {
tf = tf.Elem()
}
if isPointer && isSlice && tf.Kind() != reflect.Struct {
panic("both pointer and slice for basic type in " + tf.Name())
panic(fmt.Sprintf("%v.%s cannot be a slice of pointers to primitive types", t, f.Name))
}

switch tf.Kind() {
case reflect.Struct:
switch {
case !isPointer:
panic(fmt.Sprintf("message field %s without pointer", tf))
panic(fmt.Sprintf("%v.%s cannot be a direct struct value", t, f.Name))
case isSlice: // E.g., []*pb.T
di := getDiscardInfo(tf)
dfi.discard = func(src pointer) {
Expand All @@ -202,7 +196,7 @@ func (di *discardInfo) computeDiscardInfo() {
case reflect.Map:
switch {
case isPointer || isSlice:
panic("bad pointer or slice in map case in " + tf.Name())
panic(fmt.Sprintf("%v.%s cannot be a pointer to a map or a slice of map values", t, f.Name))
default: // E.g., map[K]V
if tf.Elem().Kind() == reflect.Ptr { // Proto struct (e.g., *T)
dfi.discard = func(src pointer) {
Expand All @@ -223,7 +217,7 @@ func (di *discardInfo) computeDiscardInfo() {
// Must be oneof field.
switch {
case isPointer || isSlice:
panic("bad pointer or slice in interface case in " + tf.Name())
panic(fmt.Sprintf("%v.%s cannot be a pointer to a interface or a slice of interface values", t, f.Name))
default: // E.g., interface{}
// TODO: Make this faster?
dfi.discard = func(src pointer) {
Expand Down Expand Up @@ -256,3 +250,101 @@ func (di *discardInfo) computeDiscardInfo() {

atomic.StoreInt32(&di.initialized, 1)
}

func discardLegacy(m Message) {
v := reflect.ValueOf(m)
if v.Kind() != reflect.Ptr || v.IsNil() {
return
}
v = v.Elem()
if v.Kind() != reflect.Struct {
return
}
t := v.Type()

for i := 0; i < v.NumField(); i++ {
f := t.Field(i)
if strings.HasPrefix(f.Name, "XXX_") {
continue
}
vf := v.Field(i)
tf := f.Type

// Unwrap tf to get its most basic type.
var isPointer, isSlice bool
if tf.Kind() == reflect.Slice && tf.Elem().Kind() != reflect.Uint8 {
isSlice = true
tf = tf.Elem()
}
if tf.Kind() == reflect.Ptr {
isPointer = true
tf = tf.Elem()
}
if isPointer && isSlice && tf.Kind() != reflect.Struct {
panic(fmt.Sprintf("%T.%s cannot be a slice of pointers to primitive types", m, f.Name))
}

switch tf.Kind() {
case reflect.Struct:
switch {
case !isPointer:
panic(fmt.Sprintf("%T.%s cannot be a direct struct value", m, f.Name))
case isSlice: // E.g., []*pb.T
for j := 0; j < vf.Len(); j++ {
discardLegacy(vf.Index(j).Interface().(Message))
}
default: // E.g., *pb.T
discardLegacy(vf.Interface().(Message))
}
case reflect.Map:
switch {
case isPointer || isSlice:
panic(fmt.Sprintf("%T.%s cannot be a pointer to a map or a slice of map values", m, f.Name))
default: // E.g., map[K]V
tv := vf.Type().Elem()
if tv.Kind() == reflect.Ptr && tv.Implements(protoMessageType) { // Proto struct (e.g., *T)
for _, key := range vf.MapKeys() {
val := vf.MapIndex(key)
discardLegacy(val.Interface().(Message))
}
}
}
case reflect.Interface:
// Must be oneof field.
switch {
case isPointer || isSlice:
panic(fmt.Sprintf("%T.%s cannot be a pointer to a interface or a slice of interface values", m, f.Name))
default: // E.g., test_proto.isCommunique_Union interface
if !vf.IsNil() && f.Tag.Get("protobuf_oneof") != "" {
vf = vf.Elem() // E.g., *test_proto.Communique_Msg
if !vf.IsNil() {
vf = vf.Elem() // E.g., test_proto.Communique_Msg
vf = vf.Field(0) // E.g., Proto struct (e.g., *T) or primitive value
if vf.Kind() == reflect.Ptr {
discardLegacy(vf.Interface().(Message))
}
}
}
}
}
}

if vf := v.FieldByName("XXX_unrecognized"); vf.IsValid() {
if vf.Type() != reflect.TypeOf([]byte{}) {
panic("expected XXX_unrecognized to be of type []byte")
}
vf.Set(reflect.ValueOf([]byte(nil)))
}

// For proto2 messages, only discard unknown fields in message extensions
// that have been accessed via GetExtension.
if em, err := extendable(m); err == nil {
// Ignore lock since discardLegacy is not concurrency safe.
emm, _ := em.extensionsRead()
for _, mx := range emm {
if m, ok := mx.value.(Message); ok {
discardLegacy(m)
}
}
}
}
52 changes: 51 additions & 1 deletion proto/discard_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,23 @@ func TestDiscardUnknown(t *testing.T) {
Name: "Aaron",
Nested: &proto3pb.Nested{Cute: true},
},
}, {
desc: "Slice",
in: &proto3pb.Message{
Name: "Aaron",
Children: []*proto3pb.Message{
{Name: "Sarah", XXX_unrecognized: []byte("blah")},
{Name: "Abraham", XXX_unrecognized: []byte("blah")},
},
XXX_unrecognized: []byte("blah"),
},
want: &proto3pb.Message{
Name: "Aaron",
Children: []*proto3pb.Message{
{Name: "Sarah"},
{Name: "Abraham"},
},
},
}, {
desc: "OneOf",
in: &pb.Communique{
Expand Down Expand Up @@ -111,10 +128,43 @@ func TestDiscardUnknown(t *testing.T) {
}(),
}}

// Test the legacy code path.
for _, tt := range tests {
// Clone the input so that we don't alter the original.
in := tt.in
if in != nil {
in = proto.Clone(tt.in)
}

var m LegacyMessage
m.Message, _ = in.(*proto3pb.Message)
m.Communique, _ = in.(*pb.Communique)
m.MessageWithMap, _ = in.(*pb.MessageWithMap)
m.MyMessage, _ = in.(*pb.MyMessage)
proto.DiscardUnknown(&m)
if !proto.Equal(in, tt.want) {
t.Errorf("test %s/Legacy, expected unknown fields to be discarded\ngot %v\nwant %v", tt.desc, in, tt.want)
}
}

for _, tt := range tests {
proto.DiscardUnknown(tt.in)
if !proto.Equal(tt.in, tt.want) {
t.Errorf("test %s, expected unknown fields to be discarde\ngot %v\nwant %v", tt.desc, tt.in, tt.want)
t.Errorf("test %s, expected unknown fields to be discarded\ngot %v\nwant %v", tt.desc, tt.in, tt.want)
}
}
}

// LegacyMessage is a proto.Message that has several nested messages.
// This does not have the XXX_DiscardUnknown method and so forces DiscardUnknown
// to use the legacy fallback logic.
type LegacyMessage struct {
Message *proto3pb.Message
Communique *pb.Communique
MessageWithMap *pb.MessageWithMap
MyMessage *pb.MyMessage
}

func (m *LegacyMessage) Reset() { *m = LegacyMessage{} }
func (m *LegacyMessage) String() string { return proto.CompactTextString(m) }
func (*LegacyMessage) ProtoMessage() {}

0 comments on commit 10c2d9d

Please sign in to comment.