From fd9d60a4e3693430d09e9d7894765e41e30522d9 Mon Sep 17 00:00:00 2001 From: Edward McFarlane <3036610+emcfarlane@users.noreply.github.com> Date: Wed, 17 Apr 2024 16:18:26 -0400 Subject: [PATCH] Fix multiple header writes for connect unary on error (#726) On error a connect unary handler will call `Close(err)` to send the final error. For non-nil errors this will always attempt to encode the header status by calling `WriteHeader`. If the body has already been written this triggers the following log via `http.Server`: ``` http: superfluous response.WriteHeader call from connectrpc.com/connect.(*connectUnaryHandlerConn).Close (protocol_connect.go:743) ``` The common case for this superfluous log is when a message send is interrupted, due to a context cancel or other write error, and the error is then attempting to re-encode the headers and body. This PR now moves the `wroteBody` check to a `wroteHeader` check on the unmarshaller. When any payload is sent, including a nil field for header only, we now stop attempting to encode any following errors, as it would be superfluous. --- connect.go | 7 ++++--- protocol_connect.go | 14 +++++++------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/connect.go b/connect.go index bd8490e5..ecc0f7ce 100644 --- a/connect.go +++ b/connect.go @@ -431,9 +431,10 @@ func receiveUnaryMessage[T any](conn receiveConn, initializer maybeInitializer, if err := initializer.maybe(conn.Spec(), &msg2); err != nil { return nil, err } - if err := conn.Receive(&msg2); err == nil { - return nil, NewError(CodeUnimplemented, fmt.Errorf("unary %s has multiple messages", what)) - } else if err != nil && !errors.Is(err, io.EOF) { + if err := conn.Receive(&msg2); !errors.Is(err, io.EOF) { + if err == nil { + err = NewError(CodeUnimplemented, fmt.Errorf("unary %s has multiple messages", what)) + } return nil, err } return &msg, nil diff --git a/protocol_connect.go b/protocol_connect.go index d478d634..e3c5e4a5 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -686,7 +686,6 @@ type connectUnaryHandlerConn struct { marshaler connectUnaryMarshaler unmarshaler connectUnaryUnmarshaler responseTrailer http.Header - wroteBody bool } func (hc *connectUnaryHandlerConn) Spec() Spec { @@ -709,8 +708,7 @@ func (hc *connectUnaryHandlerConn) RequestHeader() http.Header { } func (hc *connectUnaryHandlerConn) Send(msg any) error { - hc.wroteBody = true - hc.writeResponseHeader(nil /* error */) + hc.mergeResponseHeader(nil /* error */) if err := hc.marshaler.Marshal(msg); err != nil { return err } @@ -726,8 +724,8 @@ func (hc *connectUnaryHandlerConn) ResponseTrailer() http.Header { } func (hc *connectUnaryHandlerConn) Close(err error) error { - if !hc.wroteBody { - hc.writeResponseHeader(err) + if !hc.marshaler.wroteHeader { + hc.mergeResponseHeader(err) // If the handler received a GET request and the resource hasn't changed, // return a 304. if len(hc.peer.Query) > 0 && IsNotModifiedError(err) { @@ -735,7 +733,7 @@ func (hc *connectUnaryHandlerConn) Close(err error) error { return hc.request.Body.Close() } } - if err == nil { + if err == nil || hc.marshaler.wroteHeader { return hc.request.Body.Close() } // In unary Connect, errors always use application/json. @@ -757,7 +755,7 @@ func (hc *connectUnaryHandlerConn) getHTTPMethod() string { return hc.request.Method } -func (hc *connectUnaryHandlerConn) writeResponseHeader(err error) { +func (hc *connectUnaryHandlerConn) mergeResponseHeader(err error) { header := hc.responseWriter.Header() if hc.request.Method == http.MethodGet { // The response content varies depending on the compression that the client @@ -923,6 +921,7 @@ type connectUnaryMarshaler struct { bufferPool *bufferPool header http.Header sendMaxBytes int + wroteHeader bool } func (m *connectUnaryMarshaler) Marshal(message any) *Error { @@ -961,6 +960,7 @@ func (m *connectUnaryMarshaler) Marshal(message any) *Error { } func (m *connectUnaryMarshaler) write(data []byte) *Error { + m.wroteHeader = true payload := bytes.NewReader(data) if _, err := m.sender.Send(payload); err != nil { err = wrapIfContextError(err)