From b73bc7f225003b1de0e29013993ed80f092940c2 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Tue, 27 Jun 2023 16:28:17 +0100 Subject: [PATCH] Drain stream and error on trailing data --- connect_ext_test.go | 14 ++++++++++++++ duplex_http_call.go | 2 +- envelope.go | 6 ++++++ protocol.go | 8 +++----- protocol_grpc.go | 2 +- 5 files changed, 25 insertions(+), 7 deletions(-) diff --git a/connect_ext_test.go b/connect_ext_test.go index 5c44e95a..16e09756 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -2138,6 +2138,20 @@ func TestStreamUnexpectedEOF(t *testing.T) { }, expectCode: connect.CodeInvalidArgument, expectMsg: "invalid_argument: protocol error: incomplete envelope: unexpected EOF", + }, { + name: "stream_excess_eof", + handler: func(responseWriter http.ResponseWriter, request *http.Request) { + _, _ = responseWriter.Write(head[:]) + _, _ = responseWriter.Write(payload) + // Write EOF + _, _ = responseWriter.Write([]byte{2, 0, 0, 0, 2}) + _, _ = responseWriter.Write([]byte("{}")) + // Excess payload + _, _ = responseWriter.Write(head[:]) + _, _ = responseWriter.Write(payload) + }, + expectCode: connect.CodeUnknown, + expectMsg: fmt.Sprintf("unknown: corrupt response: %d extra bytes after end of stream", len(payload)+len(head)), }} for _, testcase := range testcases { testcaseMux[t.Name()+"/"+testcase.name] = testcase.handler diff --git a/duplex_http_call.go b/duplex_http_call.go index 4dac0092..439f55c2 100644 --- a/duplex_http_call.go +++ b/duplex_http_call.go @@ -179,7 +179,7 @@ func (d *duplexHTTPCall) CloseRead() error { if d.response == nil { return nil } - if err := discard(d.response.Body); err != nil { + if _, err := discard(d.response.Body); err != nil { _ = d.response.Body.Close() return wrapIfRSTError(err) } diff --git a/envelope.go b/envelope.go index 36e21b9a..9330a946 100644 --- a/envelope.go +++ b/envelope.go @@ -209,6 +209,12 @@ func (r *envelopeReader) Unmarshal(message any) *Error { Data: bytes.NewBuffer(copiedData), Flags: env.Flags, } + // Drain the rest of the stream to ensure there is no extra data. + if n, err := discard(r.reader); err != nil { + return errorf(CodeUnknown, "corrupt response: I/O error after end-stream message: %w", err) + } else if n > 0 { + return errorf(CodeUnknown, "corrupt response: %d extra bytes after end of stream", n) + } return errSpecialEnvelope } diff --git a/protocol.go b/protocol.go index 8486ca96..a02f24b0 100644 --- a/protocol.go +++ b/protocol.go @@ -283,16 +283,14 @@ func isCommaOrSpace(c rune) bool { return c == ',' || c == ' ' } -func discard(reader io.Reader) error { +func discard(reader io.Reader) (int64, error) { if lr, ok := reader.(*io.LimitedReader); ok { - _, err := io.Copy(io.Discard, lr) - return err + return io.Copy(io.Discard, lr) } // We don't want to get stuck throwing data away forever, so limit how much // we're willing to do here. lr := &io.LimitedReader{R: reader, N: discardLimit} - _, err := io.Copy(io.Discard, lr) - return err + return io.Copy(io.Discard, lr) } // negotiateCompression determines and validates the request compression and diff --git a/protocol_grpc.go b/protocol_grpc.go index 62e3355e..e3302a89 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -326,7 +326,7 @@ func (g *grpcClient) NewConn( } else { conn.readTrailers = func(_ *grpcUnmarshaler, call *duplexHTTPCall) http.Header { // To access HTTP trailers, we need to read the body to EOF. - _ = discard(call) + _, _ = discard(call) return call.ResponseTrailer() } }