From 8ee38cf2eb4d1389c793b36d2eede516bc48ad9e Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Wed, 17 Apr 2024 13:12:42 -0400 Subject: [PATCH 1/2] Fix multiple header writes --- connect.go | 7 ++++--- protocol_connect.go | 10 +++++----- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/connect.go b/connect.go index bd8490e5..ecc0f7ce 100644 --- a/connect.go +++ b/connect.go @@ -431,9 +431,10 @@ func receiveUnaryMessage[T any](conn receiveConn, initializer maybeInitializer, if err := initializer.maybe(conn.Spec(), &msg2); err != nil { return nil, err } - if err := conn.Receive(&msg2); err == nil { - return nil, NewError(CodeUnimplemented, fmt.Errorf("unary %s has multiple messages", what)) - } else if err != nil && !errors.Is(err, io.EOF) { + if err := conn.Receive(&msg2); !errors.Is(err, io.EOF) { + if err == nil { + err = NewError(CodeUnimplemented, fmt.Errorf("unary %s has multiple messages", what)) + } return nil, err } return &msg, nil diff --git a/protocol_connect.go b/protocol_connect.go index d478d634..f50d4141 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -709,11 +709,11 @@ func (hc *connectUnaryHandlerConn) RequestHeader() http.Header { } func (hc *connectUnaryHandlerConn) Send(msg any) error { - hc.wroteBody = true - hc.writeResponseHeader(nil /* error */) + hc.mergeResponseHeader(nil /* error */) if err := hc.marshaler.Marshal(msg); err != nil { return err } + hc.wroteBody = true return nil // must be a literal nil: nil *Error is a non-nil error } @@ -727,7 +727,7 @@ func (hc *connectUnaryHandlerConn) ResponseTrailer() http.Header { func (hc *connectUnaryHandlerConn) Close(err error) error { if !hc.wroteBody { - hc.writeResponseHeader(err) + hc.mergeResponseHeader(err) // If the handler received a GET request and the resource hasn't changed, // return a 304. if len(hc.peer.Query) > 0 && IsNotModifiedError(err) { @@ -735,7 +735,7 @@ func (hc *connectUnaryHandlerConn) Close(err error) error { return hc.request.Body.Close() } } - if err == nil { + if err == nil || hc.wroteBody { return hc.request.Body.Close() } // In unary Connect, errors always use application/json. @@ -757,7 +757,7 @@ func (hc *connectUnaryHandlerConn) getHTTPMethod() string { return hc.request.Method } -func (hc *connectUnaryHandlerConn) writeResponseHeader(err error) { +func (hc *connectUnaryHandlerConn) mergeResponseHeader(err error) { header := hc.responseWriter.Header() if hc.request.Method == http.MethodGet { // The response content varies depending on the compression that the client From 50f762a8ac8f5dd51926343ae54e5fde5610ba95 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Wed, 17 Apr 2024 13:50:35 -0400 Subject: [PATCH 2/2] Move bool to write --- protocol_connect.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/protocol_connect.go b/protocol_connect.go index f50d4141..e3c5e4a5 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -686,7 +686,6 @@ type connectUnaryHandlerConn struct { marshaler connectUnaryMarshaler unmarshaler connectUnaryUnmarshaler responseTrailer http.Header - wroteBody bool } func (hc *connectUnaryHandlerConn) Spec() Spec { @@ -713,7 +712,6 @@ func (hc *connectUnaryHandlerConn) Send(msg any) error { if err := hc.marshaler.Marshal(msg); err != nil { return err } - hc.wroteBody = true return nil // must be a literal nil: nil *Error is a non-nil error } @@ -726,7 +724,7 @@ func (hc *connectUnaryHandlerConn) ResponseTrailer() http.Header { } func (hc *connectUnaryHandlerConn) Close(err error) error { - if !hc.wroteBody { + if !hc.marshaler.wroteHeader { hc.mergeResponseHeader(err) // If the handler received a GET request and the resource hasn't changed, // return a 304. @@ -735,7 +733,7 @@ func (hc *connectUnaryHandlerConn) Close(err error) error { return hc.request.Body.Close() } } - if err == nil || hc.wroteBody { + if err == nil || hc.marshaler.wroteHeader { return hc.request.Body.Close() } // In unary Connect, errors always use application/json. @@ -923,6 +921,7 @@ type connectUnaryMarshaler struct { bufferPool *bufferPool header http.Header sendMaxBytes int + wroteHeader bool } func (m *connectUnaryMarshaler) Marshal(message any) *Error { @@ -961,6 +960,7 @@ func (m *connectUnaryMarshaler) Marshal(message any) *Error { } func (m *connectUnaryMarshaler) write(data []byte) *Error { + m.wroteHeader = true payload := bytes.NewReader(data) if _, err := m.sender.Send(payload); err != nil { err = wrapIfContextError(err)