From c6b556204c8a240157df271081099899c42d7c02 Mon Sep 17 00:00:00 2001 From: Joshua Carpeggiani <32605850+joshcarp@users.noreply.github.com> Date: Fri, 18 Nov 2022 11:01:03 -0500 Subject: [PATCH] Bump test coverage to 90% (#395) Adds more tests to 90% test coverage --- client_ext_test.go | 2 + connect_ext_test.go | 418 +++++++++++++++++++++++++++++++++++++++- error_test.go | 1 + interceptor_ext_test.go | 27 +++ 4 files changed, 446 insertions(+), 2 deletions(-) diff --git a/client_ext_test.go b/client_ext_test.go index bfe4e399..4b02554d 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -150,6 +150,7 @@ func (a *assertPeerInterceptor) WrapStreamingClient(next connect.StreamingClient conn := next(ctx, spec) assert.NotZero(a.tb, conn.Peer().Addr) assert.NotZero(a.tb, conn.Peer().Protocol) + assert.NotZero(a.tb, conn.Spec()) return conn } } @@ -158,6 +159,7 @@ func (a *assertPeerInterceptor) WrapStreamingHandler(next connect.StreamingHandl return func(ctx context.Context, conn connect.StreamingHandlerConn) error { assert.NotZero(a.tb, conn.Peer().Addr) assert.NotZero(a.tb, conn.Peer().Protocol) + assert.NotZero(a.tb, conn.Spec()) return next(ctx, conn) } } diff --git a/connect_ext_test.go b/connect_ext_test.go index 5e7f7dc3..13f181cf 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -543,6 +543,44 @@ func TestTimeoutParsing(t *testing.T) { assert.Nil(t, err) } +func TestFailCodec(t *testing.T) { + t.Parallel() + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + server := httptest.NewServer(handler) + defer server.Close() + client := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL, + connect.WithCodec(failCodec{}), + ) + stream := client.CumSum(context.Background()) + err := stream.Send(&pingv1.CumSumRequest{}) + var connectErr *connect.Error + assert.NotNil(t, err) + assert.True(t, errors.As(err, &connectErr)) + assert.Equal(t, connectErr.Code(), connect.CodeInternal) +} + +func TestContextError(t *testing.T) { + t.Parallel() + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + server := httptest.NewServer(handler) + defer server.Close() + client := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL, + ) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + stream := client.CumSum(ctx) + err := stream.Send(nil) + var connectErr *connect.Error + assert.NotNil(t, err) + assert.True(t, errors.As(err, &connectErr)) + assert.Equal(t, connectErr.Code(), connect.CodeCanceled) + assert.False(t, connect.IsWireError(err)) +} + func TestGRPCMarshalStatusError(t *testing.T) { t.Parallel() @@ -699,6 +737,43 @@ func TestBidiRequiresHTTP2(t *testing.T) { ) } +func TestCompressMinBytesClient(t *testing.T) { + t.Parallel() + assertContentType := func(tb testing.TB, text, expect string) { + tb.Helper() + mux := http.NewServeMux() + mux.Handle("/", http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + assert.Equal(tb, request.Header.Get("Content-Encoding"), expect) + })) + server := httptest.NewServer(mux) + tb.Cleanup(server.Close) + _, err := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL, + connect.WithSendGzip(), + connect.WithCompressMinBytes(8), + ).Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Text: text})) + assert.Nil(tb, err) + } + t.Run("request_uncompressed", func(t *testing.T) { + t.Parallel() + assertContentType(t, "ping", "") + }) + t.Run("request_compressed", func(t *testing.T) { + t.Parallel() + assertContentType(t, "pingping", "gzip") + }) + + t.Run("request_uncompressed", func(t *testing.T) { + t.Parallel() + assertContentType(t, "ping", "") + }) + t.Run("request_compressed", func(t *testing.T) { + t.Parallel() + assertContentType(t, strings.Repeat("ping", 2), "gzip") + }) +} + func TestCompressMinBytes(t *testing.T) { t.Parallel() mux := http.NewServeMux() @@ -1393,6 +1468,312 @@ func TestBidiStreamServerSendsFirstMessage(t *testing.T) { }) } +func TestStreamForServer(t *testing.T) { + t.Parallel() + newPingServer := func(pingServer pingv1connect.PingServiceHandler) (pingv1connect.PingServiceClient, *httptest.Server) { + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) + server := httptest.NewUnstartedServer(mux) + server.EnableHTTP2 = true + server.StartTLS() + client := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL, + ) + return client, server + } + t.Run("not-proto-message", func(t *testing.T) { + t.Parallel() + client, server := newPingServer(&pluggablePingServer{ + cumSum: func(ctx context.Context, stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse]) error { + return stream.Conn().Send("foobar") + }, + }) + t.Cleanup(server.Close) + stream := client.CumSum(context.Background()) + assert.Nil(t, stream.Send(nil)) + _, err := stream.Receive() + assert.NotNil(t, err) + assert.Equal(t, connect.CodeOf(err), connect.CodeInternal) + assert.Nil(t, stream.CloseRequest()) + }) + t.Run("nil-message", func(t *testing.T) { + t.Parallel() + client, server := newPingServer(&pluggablePingServer{ + cumSum: func(ctx context.Context, stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse]) error { + return stream.Send(nil) + }, + }) + t.Cleanup(server.Close) + stream := client.CumSum(context.Background()) + assert.Nil(t, stream.Send(nil)) + _, err := stream.Receive() + assert.NotNil(t, err) + assert.Equal(t, connect.CodeOf(err), connect.CodeUnknown) + assert.Nil(t, stream.CloseRequest()) + }) + t.Run("get-spec", func(t *testing.T) { + t.Parallel() + client, server := newPingServer(&pluggablePingServer{ + cumSum: func(ctx context.Context, stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse]) error { + assert.Equal(t, stream.Spec().StreamType, connect.StreamTypeBidi) + assert.Equal(t, stream.Spec().Procedure, "/connect.ping.v1.PingService/CumSum") + assert.False(t, stream.Spec().IsClient) + return nil + }, + }) + t.Cleanup(server.Close) + stream := client.CumSum(context.Background()) + assert.Nil(t, stream.Send(nil)) + assert.Nil(t, stream.CloseRequest()) + }) + t.Run("server-stream", func(t *testing.T) { + t.Parallel() + client, server := newPingServer(&pluggablePingServer{ + countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error { + assert.Equal(t, stream.Conn().Spec().StreamType, connect.StreamTypeServer) + assert.Equal(t, stream.Conn().Spec().Procedure, "/connect.ping.v1.PingService/CountUp") + assert.False(t, stream.Conn().Spec().IsClient) + assert.Nil(t, stream.Send(&pingv1.CountUpResponse{Number: 1})) + return nil + }, + }) + t.Cleanup(server.Close) + stream, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{})) + assert.Nil(t, err) + assert.NotNil(t, stream) + assert.Nil(t, stream.Close()) + }) + t.Run("server-stream-send", func(t *testing.T) { + t.Parallel() + client, server := newPingServer(&pluggablePingServer{ + countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error { + assert.Nil(t, stream.Send(&pingv1.CountUpResponse{Number: 1})) + return nil + }, + }) + t.Cleanup(server.Close) + stream, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{})) + assert.Nil(t, err) + assert.True(t, stream.Receive()) + msg := stream.Msg() + assert.NotNil(t, msg) + assert.Equal(t, msg.Number, 1) + assert.Nil(t, stream.Close()) + }) + t.Run("server-stream-send-nil", func(t *testing.T) { + t.Parallel() + client, server := newPingServer(&pluggablePingServer{ + countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error { + stream.ResponseHeader().Set("foo", "bar") + stream.ResponseTrailer().Set("bas", "blah") + assert.Nil(t, stream.Send(nil)) + return nil + }, + }) + t.Cleanup(server.Close) + stream, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{})) + assert.Nil(t, err) + assert.False(t, stream.Receive()) + headers := stream.ResponseHeader() + assert.NotNil(t, headers) + assert.Equal(t, headers.Get("foo"), "bar") + trailers := stream.ResponseTrailer() + assert.NotNil(t, trailers) + assert.Equal(t, trailers.Get("bas"), "blah") + assert.Nil(t, stream.Close()) + }) + t.Run("client-stream", func(t *testing.T) { + t.Parallel() + client, server := newPingServer(&pluggablePingServer{ + sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) { + assert.Equal(t, stream.Spec().StreamType, connect.StreamTypeClient) + assert.Equal(t, stream.Spec().Procedure, "/connect.ping.v1.PingService/Sum") + assert.False(t, stream.Spec().IsClient) + assert.True(t, stream.Receive()) + msg := stream.Msg() + assert.NotNil(t, msg) + assert.Equal(t, msg.Number, 1) + return connect.NewResponse(&pingv1.SumResponse{Sum: 1}), nil + }, + }) + t.Cleanup(server.Close) + stream := client.Sum(context.Background()) + assert.Nil(t, stream.Send(&pingv1.SumRequest{Number: 1})) + res, err := stream.CloseAndReceive() + assert.Nil(t, err) + assert.NotNil(t, res) + assert.Equal(t, res.Msg.Sum, 1) + }) + t.Run("client-stream-conn", func(t *testing.T) { + t.Parallel() + client, server := newPingServer(&pluggablePingServer{ + sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) { + assert.NotNil(t, stream.Conn().Send("not-proto")) + return connect.NewResponse(&pingv1.SumResponse{}), nil + }, + }) + t.Cleanup(server.Close) + stream := client.Sum(context.Background()) + assert.Nil(t, stream.Send(&pingv1.SumRequest{Number: 1})) + res, err := stream.CloseAndReceive() + assert.Nil(t, err) + assert.NotNil(t, res) + }) + t.Run("client-stream-send-msg", func(t *testing.T) { + t.Parallel() + client, server := newPingServer(&pluggablePingServer{ + sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) { + assert.Nil(t, stream.Conn().Send(&pingv1.SumResponse{Sum: 2})) + return connect.NewResponse(&pingv1.SumResponse{}), nil + }, + }) + t.Cleanup(server.Close) + stream := client.Sum(context.Background()) + assert.Nil(t, stream.Send(&pingv1.SumRequest{Number: 1})) + res, err := stream.CloseAndReceive() + assert.NotNil(t, err) + assert.Equal(t, connect.CodeOf(err), connect.CodeUnknown) + assert.Nil(t, res) + }) +} + +func TestConnectHTTPErrorCodes(t *testing.T) { + t.Parallel() + checkHTTPStatus := func(t *testing.T, connectCode connect.Code, wantHttpStatus int) { + t.Helper() + mux := http.NewServeMux() + pluggableServer := &pluggablePingServer{ + ping: func(_ context.Context, _ *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { + return nil, connect.NewError(connectCode, errors.New("error")) + }, + } + mux.Handle(pingv1connect.NewPingServiceHandler(pluggableServer)) + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + req, err := http.NewRequestWithContext( + context.Background(), + http.MethodPost, + server.URL+"/"+pingv1connect.PingServiceName+"/Ping", + strings.NewReader("{}"), + ) + assert.Nil(t, err) + req.Header.Set("Content-Type", "application/json") + resp, err := server.Client().Do(req) + assert.Nil(t, err) + assert.Equal(t, wantHttpStatus, resp.StatusCode) + defer resp.Body.Close() + connectClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + connectResp, err := connectClient.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) + assert.NotNil(t, err) + assert.Nil(t, connectResp) + } + t.Run("connect.CodeCanceled, 408", func(t *testing.T) { + t.Parallel() + checkHTTPStatus(t, connect.CodeCanceled, 408) + }) + t.Run("connect.CodeUnknown, 500", func(t *testing.T) { + t.Parallel() + checkHTTPStatus(t, connect.CodeUnknown, 500) + }) + t.Run("connect.CodeInvalidArgument, 400", func(t *testing.T) { + t.Parallel() + checkHTTPStatus(t, connect.CodeInvalidArgument, 400) + }) + t.Run("connect.CodeDeadlineExceeded, 408", func(t *testing.T) { + t.Parallel() + checkHTTPStatus(t, connect.CodeDeadlineExceeded, 408) + }) + t.Run("connect.CodeNotFound, 404", func(t *testing.T) { + t.Parallel() + checkHTTPStatus(t, connect.CodeNotFound, 404) + }) + t.Run("connect.CodeAlreadyExists, 409", func(t *testing.T) { + t.Parallel() + checkHTTPStatus(t, connect.CodeAlreadyExists, 409) + }) + t.Run("connect.CodePermissionDenied, 403", func(t *testing.T) { + t.Parallel() + checkHTTPStatus(t, connect.CodePermissionDenied, 403) + }) + t.Run("connect.CodeResourceExhausted, 429", func(t *testing.T) { + t.Parallel() + checkHTTPStatus(t, connect.CodeResourceExhausted, 429) + }) + t.Run("connect.CodeFailedPrecondition, 412", func(t *testing.T) { + t.Parallel() + checkHTTPStatus(t, connect.CodeFailedPrecondition, 412) + }) + t.Run("connect.CodeAborted, 409", func(t *testing.T) { + t.Parallel() + checkHTTPStatus(t, connect.CodeAborted, 409) + }) + t.Run("connect.CodeOutOfRange, 400", func(t *testing.T) { + t.Parallel() + checkHTTPStatus(t, connect.CodeOutOfRange, 400) + }) + t.Run("connect.CodeUnimplemented, 404", func(t *testing.T) { + t.Parallel() + checkHTTPStatus(t, connect.CodeUnimplemented, 404) + }) + t.Run("connect.CodeInternal, 500", func(t *testing.T) { + t.Parallel() + checkHTTPStatus(t, connect.CodeInternal, 500) + }) + t.Run("connect.CodeUnavailable, 503", func(t *testing.T) { + t.Parallel() + checkHTTPStatus(t, connect.CodeUnavailable, 503) + }) + t.Run("connect.CodeDataLoss, 500", func(t *testing.T) { + t.Parallel() + checkHTTPStatus(t, connect.CodeDataLoss, 500) + }) + t.Run("connect.CodeUnauthenticated, 401", func(t *testing.T) { + t.Parallel() + checkHTTPStatus(t, connect.CodeUnauthenticated, 401) + }) + t.Run("100, 500", func(t *testing.T) { + t.Parallel() + checkHTTPStatus(t, 100, 500) + }) + // t.Run("0, 500", func(t *testing.T) { //TODO: enable this when + // t.Parallel() + // checkHTTPStatus(t, 0, 500) + // }) +} + +func TestFailCompression(t *testing.T) { + t.Parallel() + mux := http.NewServeMux() + compressorName := "fail" + compressor := func() connect.Compressor { return failCompressor{} } + decompressor := func() connect.Decompressor { return failDecompressor{} } + mux.Handle( + pingv1connect.NewPingServiceHandler( + pingServer{}, + connect.WithCompression(compressorName, decompressor, compressor), + ), + ) + server := httptest.NewUnstartedServer(mux) + server.EnableHTTP2 = true + server.StartTLS() + t.Cleanup(server.Close) + pingclient := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL, + connect.WithAcceptCompression(compressorName, decompressor, compressor), + connect.WithSendCompression(compressorName), + ) + _, err := pingclient.Ping( + context.Background(), + connect.NewRequest(&pingv1.PingRequest{ + Text: "ping", + }), + ) + assert.NotNil(t, err) + assert.Equal(t, connect.CodeOf(err), connect.CodeInternal) +} + func gzipCompressedSize(tb testing.TB, message proto.Message) int { tb.Helper() uncompressed, err := proto.Marshal(message) @@ -1426,8 +1807,10 @@ func (c failCodec) Unmarshal(data []byte, message any) error { type pluggablePingServer struct { pingv1connect.UnimplementedPingServiceHandler - ping func(context.Context, *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) - cumSum func(context.Context, *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse]) error + ping func(context.Context, *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) + sum func(context.Context, *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) + countUp func(context.Context, *connect.Request[pingv1.CountUpRequest], *connect.ServerStream[pingv1.CountUpResponse]) error + cumSum func(context.Context, *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse]) error } func (p *pluggablePingServer) Ping( @@ -1437,6 +1820,21 @@ func (p *pluggablePingServer) Ping( return p.ping(ctx, request) } +func (p *pluggablePingServer) Sum( + ctx context.Context, + stream *connect.ClientStream[pingv1.SumRequest], +) (*connect.Response[pingv1.SumResponse], error) { + return p.sum(ctx, stream) +} + +func (p *pluggablePingServer) CountUp( + ctx context.Context, + req *connect.Request[pingv1.CountUpRequest], + stream *connect.ServerStream[pingv1.CountUpResponse], +) error { + return p.countUp(ctx, req, stream) +} + func (p *pluggablePingServer) CumSum( ctx context.Context, stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse], @@ -1684,3 +2082,19 @@ func newHTTPMiddlewareError() *connect.Error { err.Meta().Set("Middleware-Foo", "bar") return err } + +type failDecompressor struct { + connect.Decompressor +} + +type failCompressor struct{} + +func (failCompressor) Write([]byte) (int, error) { + return 0, errors.New("failCompressor") +} + +func (failCompressor) Close() error { + return errors.New("failCompressor") +} + +func (failCompressor) Reset(io.Writer) {} diff --git a/error_test.go b/error_test.go index c5618e28..be276172 100644 --- a/error_test.go +++ b/error_test.go @@ -38,6 +38,7 @@ func TestErrorNilUnderlying(t *testing.T) { assert.Nil(t, detailErr) err.AddDetail(detail) assert.Equal(t, len(err.Details()), 1) + assert.Equal(t, err.Details()[0].Type(), "google.protobuf.Empty") err.Meta().Set("foo", "bar") assert.Equal(t, err.Meta().Get("foo"), "bar") assert.Equal(t, CodeOf(err), CodeUnknown) diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index 81ab3d96..78913fc9 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -139,6 +139,33 @@ func TestOnionOrderingEndToEnd(t *testing.T) { assert.Nil(t, responses.Close()) } +func TestEmptyUnaryInterceptorFunc(t *testing.T) { + t.Parallel() + mux := http.NewServeMux() + interceptor := connect.UnaryInterceptorFunc(func(next connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) { + return next(ctx, request) + } + }) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, connect.WithInterceptors(interceptor))) + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + connectClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithInterceptors(interceptor)) + _, err := connectClient.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) + assert.Nil(t, err) + sumStream := connectClient.Sum(context.Background()) + assert.Nil(t, sumStream.Send(&pingv1.SumRequest{Number: 1})) + resp, err := sumStream.CloseAndReceive() + assert.Nil(t, err) + assert.NotNil(t, resp) + countUpStream, err := connectClient.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{})) + assert.Nil(t, err) + for countUpStream.Receive() { + assert.NotNil(t, countUpStream.Msg()) + } + assert.Nil(t, countUpStream.Close()) +} + // headerInterceptor makes it easier to write interceptors that inspect or // mutate HTTP headers. It applies the same logic to unary and streaming // procedures, wrapping the send or receive side of the stream as appropriate.