From 8e45b51c55ce1e9a07cee002a5482cda48b799c5 Mon Sep 17 00:00:00 2001 From: Joshua Carpeggiani <32605850+joshcarp@users.noreply.github.com> Date: Fri, 18 Nov 2022 16:32:30 -0500 Subject: [PATCH] Fix panic on zero send from server (#398) - Maps a zero code of `connectWireError` to a CodeUnknown in `connectWireError) asError() *Error ` This could have been done in the unmarshaling to `connectWireError` but I chose to leave that to be exactly what was read off the wire. - Adds nil guard in `validateResponse` for asError - Adds tests for error mapping Fixes: https://github.com/bufbuild/connect-go/issues/396 --- connect_ext_test.go | 44 ++++++++++++++++++++++---------------------- protocol_connect.go | 8 +++++++- 2 files changed, 29 insertions(+), 23 deletions(-) diff --git a/connect_ext_test.go b/connect_ext_test.go index 13f181cf..aa0da553 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -1661,85 +1661,85 @@ func TestConnectHTTPErrorCodes(t *testing.T) { req.Header.Set("Content-Type", "application/json") resp, err := server.Client().Do(req) assert.Nil(t, err) - assert.Equal(t, wantHttpStatus, resp.StatusCode) defer resp.Body.Close() + assert.Equal(t, wantHttpStatus, resp.StatusCode) connectClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL) connectResp, err := connectClient.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) assert.NotNil(t, err) assert.Nil(t, connectResp) } - t.Run("connect.CodeCanceled, 408", func(t *testing.T) { + t.Run("CodeCanceled-408", func(t *testing.T) { t.Parallel() checkHTTPStatus(t, connect.CodeCanceled, 408) }) - t.Run("connect.CodeUnknown, 500", func(t *testing.T) { + t.Run("CodeUnknown-500", func(t *testing.T) { t.Parallel() checkHTTPStatus(t, connect.CodeUnknown, 500) }) - t.Run("connect.CodeInvalidArgument, 400", func(t *testing.T) { + t.Run("CodeInvalidArgument-400", func(t *testing.T) { t.Parallel() checkHTTPStatus(t, connect.CodeInvalidArgument, 400) }) - t.Run("connect.CodeDeadlineExceeded, 408", func(t *testing.T) { + t.Run("CodeDeadlineExceeded-408", func(t *testing.T) { t.Parallel() checkHTTPStatus(t, connect.CodeDeadlineExceeded, 408) }) - t.Run("connect.CodeNotFound, 404", func(t *testing.T) { + t.Run("CodeNotFound-404", func(t *testing.T) { t.Parallel() checkHTTPStatus(t, connect.CodeNotFound, 404) }) - t.Run("connect.CodeAlreadyExists, 409", func(t *testing.T) { + t.Run("CodeAlreadyExists-409", func(t *testing.T) { t.Parallel() checkHTTPStatus(t, connect.CodeAlreadyExists, 409) }) - t.Run("connect.CodePermissionDenied, 403", func(t *testing.T) { + t.Run("CodePermissionDenied-403", func(t *testing.T) { t.Parallel() checkHTTPStatus(t, connect.CodePermissionDenied, 403) }) - t.Run("connect.CodeResourceExhausted, 429", func(t *testing.T) { + t.Run("CodeResourceExhausted-429", func(t *testing.T) { t.Parallel() checkHTTPStatus(t, connect.CodeResourceExhausted, 429) }) - t.Run("connect.CodeFailedPrecondition, 412", func(t *testing.T) { + t.Run("CodeFailedPrecondition-412", func(t *testing.T) { t.Parallel() checkHTTPStatus(t, connect.CodeFailedPrecondition, 412) }) - t.Run("connect.CodeAborted, 409", func(t *testing.T) { + t.Run("CodeAborted-409", func(t *testing.T) { t.Parallel() checkHTTPStatus(t, connect.CodeAborted, 409) }) - t.Run("connect.CodeOutOfRange, 400", func(t *testing.T) { + t.Run("CodeOutOfRange-400", func(t *testing.T) { t.Parallel() checkHTTPStatus(t, connect.CodeOutOfRange, 400) }) - t.Run("connect.CodeUnimplemented, 404", func(t *testing.T) { + t.Run("CodeUnimplemented-404", func(t *testing.T) { t.Parallel() checkHTTPStatus(t, connect.CodeUnimplemented, 404) }) - t.Run("connect.CodeInternal, 500", func(t *testing.T) { + t.Run("CodeInternal-500", func(t *testing.T) { t.Parallel() checkHTTPStatus(t, connect.CodeInternal, 500) }) - t.Run("connect.CodeUnavailable, 503", func(t *testing.T) { + t.Run("CodeUnavailable-503", func(t *testing.T) { t.Parallel() checkHTTPStatus(t, connect.CodeUnavailable, 503) }) - t.Run("connect.CodeDataLoss, 500", func(t *testing.T) { + t.Run("CodeDataLoss-500", func(t *testing.T) { t.Parallel() checkHTTPStatus(t, connect.CodeDataLoss, 500) }) - t.Run("connect.CodeUnauthenticated, 401", func(t *testing.T) { + t.Run("CodeUnauthenticated-401", func(t *testing.T) { t.Parallel() checkHTTPStatus(t, connect.CodeUnauthenticated, 401) }) - t.Run("100, 500", func(t *testing.T) { + t.Run("100-500", func(t *testing.T) { t.Parallel() checkHTTPStatus(t, 100, 500) }) - // t.Run("0, 500", func(t *testing.T) { //TODO: enable this when - // t.Parallel() - // checkHTTPStatus(t, 0, 500) - // }) + t.Run("0-500", func(t *testing.T) { + t.Parallel() + checkHTTPStatus(t, 0, 500) + }) } func TestFailCompression(t *testing.T) { diff --git a/protocol_connect.go b/protocol_connect.go index 13d32e23..54858c8c 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -423,6 +423,9 @@ func (cc *connectUnaryClientConn) validateResponse(response *http.Response) *Err ) } serverErr := wireErr.asError() + if serverErr == nil { + return nil + } serverErr.meta = cc.responseHeader.Clone() mergeHeaders(serverErr.meta, cc.responseTrailer) return serverErr @@ -919,9 +922,12 @@ func newConnectWireError(err error) *connectWireError { } func (e *connectWireError) asError() *Error { - if e == nil || e.Code == 0 { + if e == nil { return nil } + if e.Code < minCode || e.Code > maxCode { + e.Code = CodeUnknown + } err := NewError(e.Code, errors.New(e.Message)) err.wireErr = true if len(e.Details) > 0 {