diff --git a/proto/all_test.go b/proto/all_test.go index a9da89fca76c..e7c1ff9f7daa 100644 --- a/proto/all_test.go +++ b/proto/all_test.go @@ -1860,6 +1860,25 @@ func TestRequiredNotSetError(t *testing.T) { } } +func TestRequiredNotSetErrorWithBadWireTypes(t *testing.T) { + // Required field expects a varint, and properly found a varint. + if err := Unmarshal([]byte{0x08, 0x00}, new(GoEnum)); err != nil { + t.Errorf("Unmarshal = %v, want nil", err) + } + // Required field expects a varint, but found a fixed32 instead. + if err := Unmarshal([]byte{0x0d, 0x00, 0x00, 0x00, 0x00}, new(GoEnum)); err == nil { + t.Errorf("Unmarshal = nil, want RequiredNotSetError") + } + // Required field expects a varint, and found both a varint and fixed32 (ignored). + m := new(GoEnum) + if err := Unmarshal([]byte{0x08, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x00}, m); err != nil { + t.Errorf("Unmarshal = %v, want nil", err) + } + if !bytes.Equal(m.XXX_unrecognized, []byte{0x0d, 0x00, 0x00, 0x00, 0x00}) { + t.Errorf("expected fixed32 to appear as unknown bytes: %x", m.XXX_unrecognized) + } +} + func fuzzUnmarshal(t *testing.T, data []byte) { defer func() { if e := recover(); e != nil { diff --git a/proto/table_unmarshal.go b/proto/table_unmarshal.go index 8c0c10db5103..55f0340a3fde 100644 --- a/proto/table_unmarshal.go +++ b/proto/table_unmarshal.go @@ -166,20 +166,21 @@ func (u *unmarshalInfo) unmarshal(m pointer, b []byte) error { } else { f = u.sparse[tag] } - reqMask |= f.reqMask if fn := f.unmarshal; fn != nil { var err error b, err = fn(b, m.offset(f.field), wire) if err == nil { + reqMask |= f.reqMask continue } if r, ok := err.(*RequiredNotSetError); ok { // Remember this error, but keep parsing. We need to produce // a full parse even if a required field is missing. rnse = r + reqMask |= f.reqMask continue } - if err == nil || err != errInternalBadWireType { + if err != errInternalBadWireType { return err } // Fragments with bad wire type are treated as unknown fields.