diff --git a/proto/all_test.go b/proto/all_test.go index 60a1444a7589..573410885bfb 100644 --- a/proto/all_test.go +++ b/proto/all_test.go @@ -2184,6 +2184,21 @@ func TestConcurrentMarshal(t *testing.T) { } } +func TestInvalidUTF8(t *testing.T) { + const wire = "\x12\x04\xde\xea\xca\xfe" + + var m GoTest + if err := Unmarshal([]byte(wire), &m); err == nil { + t.Errorf("Unmarshal error: got nil, want non-nil") + } + + m.Reset() + m.Table = String(wire[2:]) + if _, err := Marshal(&m); err == nil { + t.Errorf("Marshal error: got nil, want non-nil") + } +} + // Benchmarks func testMsg() *GoTest { diff --git a/proto/lib.go b/proto/lib.go index 7d8796545e1c..a1e22cf50074 100644 --- a/proto/lib.go +++ b/proto/lib.go @@ -265,6 +265,7 @@ package proto import ( "encoding/json" + "errors" "fmt" "log" "reflect" @@ -279,6 +280,8 @@ import ( // This variable is temporary and will go away soon. Do not rely on it. var Proto3UnknownFields = false +var errInvalidUTF8 = errors.New("proto: invalid UTF-8 string") + // Message is implemented by generated protocol buffer messages. type Message interface { Reset() diff --git a/proto/table_marshal.go b/proto/table_marshal.go index a5b6565969fa..8097715a3e36 100644 --- a/proto/table_marshal.go +++ b/proto/table_marshal.go @@ -41,6 +41,7 @@ import ( "strings" "sync" "sync/atomic" + "unicode/utf8" ) // a sizer takes a pointer to a field and the size of its tag, computes the size of @@ -1983,6 +1984,9 @@ func appendBoolPackedSlice(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byt } func appendStringValue(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) { v := *ptr.toString() + if !utf8.ValidString(v) { + return nil, errInvalidUTF8 + } b = appendVarint(b, wiretag) b = appendVarint(b, uint64(len(v))) b = append(b, v...) @@ -1993,6 +1997,9 @@ func appendStringValueNoZero(b []byte, ptr pointer, wiretag uint64, _ bool) ([]b if v == "" { return b, nil } + if !utf8.ValidString(v) { + return nil, errInvalidUTF8 + } b = appendVarint(b, wiretag) b = appendVarint(b, uint64(len(v))) b = append(b, v...) @@ -2004,6 +2011,9 @@ func appendStringPtr(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, err return b, nil } v := *p + if !utf8.ValidString(v) { + return nil, errInvalidUTF8 + } b = appendVarint(b, wiretag) b = appendVarint(b, uint64(len(v))) b = append(b, v...) @@ -2012,6 +2022,9 @@ func appendStringPtr(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, err func appendStringSlice(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) { s := *ptr.toStringSlice() for _, v := range s { + if !utf8.ValidString(v) { + return nil, errInvalidUTF8 + } b = appendVarint(b, wiretag) b = appendVarint(b, uint64(len(v))) b = append(b, v...) diff --git a/proto/table_unmarshal.go b/proto/table_unmarshal.go index bfd8d909e4b1..a0b52486e968 100644 --- a/proto/table_unmarshal.go +++ b/proto/table_unmarshal.go @@ -41,6 +41,7 @@ import ( "strings" "sync" "sync/atomic" + "unicode/utf8" ) // Unmarshal is the entry point from the generated .pb.go files. @@ -1514,6 +1515,9 @@ func unmarshalStringValue(b []byte, f pointer, w int) ([]byte, error) { return nil, io.ErrUnexpectedEOF } v := string(b[:x]) + if !utf8.ValidString(v) { + return nil, errInvalidUTF8 + } *f.toString() = v return b[x:], nil } @@ -1531,6 +1535,9 @@ func unmarshalStringPtr(b []byte, f pointer, w int) ([]byte, error) { return nil, io.ErrUnexpectedEOF } v := string(b[:x]) + if !utf8.ValidString(v) { + return nil, errInvalidUTF8 + } *f.toStringPtr() = &v return b[x:], nil } @@ -1548,6 +1555,9 @@ func unmarshalStringSlice(b []byte, f pointer, w int) ([]byte, error) { return nil, io.ErrUnexpectedEOF } v := string(b[:x]) + if !utf8.ValidString(v) { + return nil, errInvalidUTF8 + } s := f.toStringSlice() *s = append(*s, v) return b[x:], nil