diff --git a/connect_ext_test.go b/connect_ext_test.go index ec2b7121..a6d2b500 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -2115,8 +2115,10 @@ func TestStreamUnexpectedEOF(t *testing.T) { handler: func(responseWriter http.ResponseWriter, _ *http.Request) { header := responseWriter.Header() header.Set("Content-Type", "application/connect+json") - _, _ = responseWriter.Write(head[:]) - _, _ = responseWriter.Write(payload) + _, err := responseWriter.Write(head[:]) + assert.Nil(t, err) + _, err = responseWriter.Write(payload) + assert.Nil(t, err) }, expectCode: connect.CodeInternal, expectMsg: "internal: protocol error: unexpected EOF", @@ -2126,8 +2128,10 @@ func TestStreamUnexpectedEOF(t *testing.T) { handler: func(responseWriter http.ResponseWriter, _ *http.Request) { header := responseWriter.Header() header.Set("Content-Type", "application/grpc+json") - _, _ = responseWriter.Write(head[:]) - _, _ = responseWriter.Write(payload) + _, err := responseWriter.Write(head[:]) + assert.Nil(t, err) + _, err = responseWriter.Write(payload) + assert.Nil(t, err) }, expectCode: connect.CodeInternal, expectMsg: "internal: protocol error: no Grpc-Status trailer: unexpected EOF", @@ -2137,8 +2141,10 @@ func TestStreamUnexpectedEOF(t *testing.T) { handler: func(responseWriter http.ResponseWriter, _ *http.Request) { header := responseWriter.Header() header.Set("Content-Type", "application/grpc-web+json") - _, _ = responseWriter.Write(head[:]) + _, err := responseWriter.Write(head[:]) + assert.Nil(t, err) _, _ = responseWriter.Write(payload) + assert.Nil(t, err) }, expectCode: connect.CodeInternal, expectMsg: "internal: protocol error: no Grpc-Status trailer: unexpected EOF", @@ -2148,8 +2154,10 @@ func TestStreamUnexpectedEOF(t *testing.T) { handler: func(responseWriter http.ResponseWriter, _ *http.Request) { header := responseWriter.Header() header.Set("Content-Type", "application/connect+json") - _, _ = responseWriter.Write(head[:]) - _, _ = responseWriter.Write(payload[:len(payload)-1]) + _, err := responseWriter.Write(head[:]) + assert.Nil(t, err) + _, err = responseWriter.Write(payload[:len(payload)-1]) + assert.Nil(t, err) }, expectCode: connect.CodeInvalidArgument, expectMsg: fmt.Sprintf("invalid_argument: protocol error: promised %d bytes in enveloped message, got %d bytes", len(payload), len(payload)-1), @@ -2159,8 +2167,10 @@ func TestStreamUnexpectedEOF(t *testing.T) { handler: func(responseWriter http.ResponseWriter, _ *http.Request) { header := responseWriter.Header() header.Set("Content-Type", "application/grpc+json") - _, _ = responseWriter.Write(head[:]) - _, _ = responseWriter.Write(payload[:len(payload)-1]) + _, err := responseWriter.Write(head[:]) + assert.Nil(t, err) + _, err = responseWriter.Write(payload[:len(payload)-1]) + assert.Nil(t, err) }, expectCode: connect.CodeInvalidArgument, expectMsg: fmt.Sprintf("invalid_argument: protocol error: promised %d bytes in enveloped message, got %d bytes", len(payload), len(payload)-1), @@ -2170,8 +2180,10 @@ func TestStreamUnexpectedEOF(t *testing.T) { handler: func(responseWriter http.ResponseWriter, _ *http.Request) { header := responseWriter.Header() header.Set("Content-Type", "application/grpc-web+json") - _, _ = responseWriter.Write(head[:]) - _, _ = responseWriter.Write(payload[:len(payload)-1]) + _, err := responseWriter.Write(head[:]) + assert.Nil(t, err) + _, err = responseWriter.Write(payload[:len(payload)-1]) + assert.Nil(t, err) }, expectCode: connect.CodeInvalidArgument, expectMsg: fmt.Sprintf("invalid_argument: protocol error: promised %d bytes in enveloped message, got %d bytes", len(payload), len(payload)-1), @@ -2181,7 +2193,8 @@ func TestStreamUnexpectedEOF(t *testing.T) { handler: func(responseWriter http.ResponseWriter, _ *http.Request) { header := responseWriter.Header() header.Set("Content-Type", "application/connect+json") - _, _ = responseWriter.Write(head[:4]) + _, err := responseWriter.Write(head[:4]) + assert.Nil(t, err) }, expectCode: connect.CodeInvalidArgument, expectMsg: "invalid_argument: protocol error: incomplete envelope: unexpected EOF", @@ -2191,7 +2204,8 @@ func TestStreamUnexpectedEOF(t *testing.T) { handler: func(responseWriter http.ResponseWriter, _ *http.Request) { header := responseWriter.Header() header.Set("Content-Type", "application/grpc+json") - _, _ = responseWriter.Write(head[:4]) + _, err := responseWriter.Write(head[:4]) + assert.Nil(t, err) }, expectCode: connect.CodeInvalidArgument, expectMsg: "invalid_argument: protocol error: incomplete envelope: unexpected EOF", @@ -2201,10 +2215,59 @@ func TestStreamUnexpectedEOF(t *testing.T) { handler: func(responseWriter http.ResponseWriter, _ *http.Request) { header := responseWriter.Header() header.Set("Content-Type", "application/grpc-web+json") - _, _ = responseWriter.Write(head[:4]) + _, err := responseWriter.Write(head[:4]) + assert.Nil(t, err) }, expectCode: connect.CodeInvalidArgument, expectMsg: "invalid_argument: protocol error: incomplete envelope: unexpected EOF", + }, { + name: "connect_excess_eof", + options: []connect.ClientOption{connect.WithProtoJSON()}, + handler: func(responseWriter http.ResponseWriter, _ *http.Request) { + _, err := responseWriter.Write(head[:]) + assert.Nil(t, err) + _, err = responseWriter.Write(payload) + assert.Nil(t, err) + // Write EOF + _, err = responseWriter.Write([]byte{1 << 1, 0, 0, 0, 2}) + assert.Nil(t, err) + _, err = responseWriter.Write([]byte("{}")) + assert.Nil(t, err) + // Excess payload + _, err = responseWriter.Write(head[:]) + assert.Nil(t, err) + _, err = responseWriter.Write(payload) + assert.Nil(t, err) + }, + expectCode: connect.CodeInternal, + expectMsg: fmt.Sprintf("internal: corrupt response: %d extra bytes after end of stream", len(payload)+len(head)), + }, { + name: "grpc-web_excess_eof", + options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPCWeb()}, + handler: func(responseWriter http.ResponseWriter, _ *http.Request) { + _, err := responseWriter.Write(head[:]) + assert.Nil(t, err) + _, err = responseWriter.Write(payload) + assert.Nil(t, err) + // Write EOF + var buf bytes.Buffer + trailer := http.Header{"grpc-status": []string{"0"}} + assert.Nil(t, trailer.Write(&buf)) + var head [5]byte + head[0] = 1 << 7 + binary.BigEndian.PutUint32(head[1:], uint32(buf.Len())) + _, err = responseWriter.Write(head[:]) + assert.Nil(t, err) + _, err = responseWriter.Write(buf.Bytes()) + assert.Nil(t, err) + // Excess payload + _, err = responseWriter.Write(head[:]) + assert.Nil(t, err) + _, err = responseWriter.Write(payload) + assert.Nil(t, err) + }, + expectCode: connect.CodeInternal, + expectMsg: fmt.Sprintf("internal: corrupt response: %d extra bytes after end of stream", len(payload)+len(head)), }} for _, testcase := range testcases { testcaseMux[t.Name()+"/"+testcase.name] = testcase.handler @@ -2223,11 +2286,10 @@ func TestStreamUnexpectedEOF(t *testing.T) { request.Header().Set("Test-Case", t.Name()) stream, err := client.CountUp(context.Background(), request) assert.Nil(t, err) - for stream.Receive() { + for i := 0; stream.Receive() && i < upTo; i++ { assert.Equal(t, stream.Msg().Number, 42) } assert.NotNil(t, stream.Err()) - t.Log(stream.Err()) assert.Equal(t, connect.CodeOf(stream.Err()), testcase.expectCode) assert.Equal(t, stream.Err().Error(), testcase.expectMsg) }) diff --git a/duplex_http_call.go b/duplex_http_call.go index 4dac0092..439f55c2 100644 --- a/duplex_http_call.go +++ b/duplex_http_call.go @@ -179,7 +179,7 @@ func (d *duplexHTTPCall) CloseRead() error { if d.response == nil { return nil } - if err := discard(d.response.Body); err != nil { + if _, err := discard(d.response.Body); err != nil { _ = d.response.Body.Close() return wrapIfRSTError(err) } diff --git a/envelope.go b/envelope.go index b45f0c48..559ede69 100644 --- a/envelope.go +++ b/envelope.go @@ -199,6 +199,12 @@ func (r *envelopeReader) Unmarshal(message any) *Error { } if env.Flags != 0 && env.Flags != flagEnvelopeCompressed { + // Drain the rest of the stream to ensure there is no extra data. + if n, err := discard(r.reader); err != nil { + return errorf(CodeInternal, "corrupt response: I/O error after end-stream message: %w", err) + } else if n > 0 { + return errorf(CodeInternal, "corrupt response: %d extra bytes after end of stream", n) + } // One of the protocol-specific flags are set, so this is the end of the // stream. Save the message for protocol-specific code to process and // return a sentinel error. Since we've deferred functions to return env's diff --git a/protocol.go b/protocol.go index 8486ca96..a02f24b0 100644 --- a/protocol.go +++ b/protocol.go @@ -283,16 +283,14 @@ func isCommaOrSpace(c rune) bool { return c == ',' || c == ' ' } -func discard(reader io.Reader) error { +func discard(reader io.Reader) (int64, error) { if lr, ok := reader.(*io.LimitedReader); ok { - _, err := io.Copy(io.Discard, lr) - return err + return io.Copy(io.Discard, lr) } // We don't want to get stuck throwing data away forever, so limit how much // we're willing to do here. lr := &io.LimitedReader{R: reader, N: discardLimit} - _, err := io.Copy(io.Discard, lr) - return err + return io.Copy(io.Discard, lr) } // negotiateCompression determines and validates the request compression and diff --git a/protocol_grpc.go b/protocol_grpc.go index ea290931..44dfaf69 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -326,7 +326,7 @@ func (g *grpcClient) NewConn( } else { conn.readTrailers = func(_ *grpcUnmarshaler, call *duplexHTTPCall) http.Header { // To access HTTP trailers, we need to read the body to EOF. - _ = discard(call) + _, _ = discard(call) return call.ResponseTrailer() } }