diff --git a/.golangci.yml b/.golangci.yml index db3c6bb7..44146370 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -75,6 +75,9 @@ issues: - linters: [varnamelen] path: cmd/protoc-gen-connect-go text: "parameter name 'g' is too short" + # Thorough error logging and timeout config make this example unreadably long. + - linters: [errcheck, gosec] + path: error_writer_example_test.go # It should be crystal clear that Connect uses plain *http.Clients. - linters: [revive, stylecheck] path: client_example_test.go diff --git a/connect_ext_test.go b/connect_ext_test.go index 895bdf22..34f301e1 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -44,11 +44,12 @@ const errorMessage = "oh no" // client doesn't set a header, and the server sets headers and trailers on the // response. const ( - headerValue = "some header value" - trailerValue = "some trailer value" - clientHeader = "Connect-Client-Header" - handlerHeader = "Connect-Handler-Header" - handlerTrailer = "Connect-Handler-Trailer" + headerValue = "some header value" + trailerValue = "some trailer value" + clientHeader = "Connect-Client-Header" + handlerHeader = "Connect-Handler-Header" + handlerTrailer = "Connect-Handler-Trailer" + clientMiddlewareErrorHeader = "Connect-Trigger-HTTP-Error" ) func TestServer(t *testing.T) { @@ -289,6 +290,19 @@ func TestServer(t *testing.T) { }) } testErrors := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper + assertIsHTTPMiddlewareError := func(tb testing.TB, err error) { + tb.Helper() + assert.NotNil(tb, err) + var connectErr *connect.Error + assert.True(tb, errors.As(err, &connectErr)) + expect := newHTTPMiddlewareError() + assert.Equal(tb, connectErr.Code(), expect.Code()) + assert.Equal(tb, connectErr.Message(), expect.Message()) + for k, v := range expect.Meta() { + assert.Equal(tb, connectErr.Meta().Values(k), v) + } + assert.Equal(tb, len(connectErr.Details()), len(expect.Details())) + } t.Run("errors", func(t *testing.T) { request := connect.NewRequest(&pingv1.FailRequest{ Code: int32(connect.CodeResourceExhausted), @@ -307,6 +321,20 @@ func TestServer(t *testing.T) { assert.Equal(t, connectErr.Meta().Values(handlerHeader), []string{headerValue}) assert.Equal(t, connectErr.Meta().Values(handlerTrailer), []string{trailerValue}) }) + t.Run("middleware_errors_unary", func(t *testing.T) { + request := connect.NewRequest(&pingv1.PingRequest{}) + request.Header().Set(clientMiddlewareErrorHeader, headerValue) + _, err := client.Ping(context.Background(), request) + assertIsHTTPMiddlewareError(t, err) + }) + t.Run("middleware_errors_streaming", func(t *testing.T) { + request := connect.NewRequest(&pingv1.CountUpRequest{Number: 10}) + request.Header().Set(clientMiddlewareErrorHeader, headerValue) + stream, err := client.CountUp(context.Background(), request) + assert.Nil(t, err) + assert.False(t, stream.Receive()) + assertIsHTTPMiddlewareError(t, stream.Err()) + }) } testMatrix := func(t *testing.T, server *httptest.Server, bidi bool) { //nolint:thelper run := func(t *testing.T, opts ...connect.ClientOption) { @@ -368,9 +396,27 @@ func TestServer(t *testing.T) { } mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler( + pingRoute, pingHandler := pingv1connect.NewPingServiceHandler( pingServer{checkMetadata: true}, - )) + ) + errorWriter := connect.NewErrorWriter() + // Add some net/http middleware to the ping service so we can also exercise ErrorWriter. + mux.Handle(pingRoute, http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { + if request.Header.Get(clientMiddlewareErrorHeader) != "" { + defer request.Body.Close() + if _, err := io.Copy(io.Discard, request.Body); err != nil { + t.Errorf("drain request body: %v", err) + } + if !errorWriter.IsSupported(request) { + t.Errorf("ErrorWriter doesn't support Content-Type %q", request.Header.Get("Content-Type")) + } + if err := errorWriter.Write(response, request, newHTTPMiddlewareError()); err != nil { + t.Errorf("send RPC error from HTTP middleware: %v", err) + } + return + } + pingHandler.ServeHTTP(response, request) + })) t.Run("http1", func(t *testing.T) { t.Parallel() @@ -1450,3 +1496,9 @@ func (l *trimTrailerWriter) removeTrailers() { } } } + +func newHTTPMiddlewareError() *connect.Error { + err := connect.NewError(connect.CodeResourceExhausted, errors.New("error from HTTP middleware")) + err.Meta().Set("Middleware-Foo", "bar") + return err +} diff --git a/error_writer.go b/error_writer.go new file mode 100644 index 00000000..efca9da1 --- /dev/null +++ b/error_writer.go @@ -0,0 +1,168 @@ +// Copyright 2021-2022 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connect + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" +) + +// An ErrorWriter writes errors to an [http.ResponseWriter] in the format +// expected by an RPC client. This is especially useful in server-side net/http +// middleware, where you may wish to handle requests from RPC and non-RPC +// clients with the same code. +// +// ErrorWriters are safe to use concurrently. +type ErrorWriter struct { + bufferPool *bufferPool + protobuf Codec + allContentTypes map[string]struct{} + grpcContentTypes map[string]struct{} + grpcWebContentTypes map[string]struct{} + unaryConnectContentTypes map[string]struct{} + streamingConnectContentTypes map[string]struct{} +} + +// NewErrorWriter constructs an ErrorWriter. To properly recognize supported +// RPC Content-Types in net/http middleware, you must pass the same +// HandlerOptions to NewErrorWriter and any wrapped Connect handlers. +func NewErrorWriter(opts ...HandlerOption) *ErrorWriter { + config := newHandlerConfig("", opts) + writer := &ErrorWriter{ + bufferPool: config.BufferPool, + protobuf: newReadOnlyCodecs(config.Codecs).Protobuf(), + allContentTypes: make(map[string]struct{}), + grpcContentTypes: make(map[string]struct{}), + grpcWebContentTypes: make(map[string]struct{}), + unaryConnectContentTypes: make(map[string]struct{}), + streamingConnectContentTypes: make(map[string]struct{}), + } + for name := range config.Codecs { + unary := connectContentTypeFromCodecName(StreamTypeUnary, name) + writer.allContentTypes[unary] = struct{}{} + writer.unaryConnectContentTypes[unary] = struct{}{} + streaming := connectContentTypeFromCodecName(StreamTypeBidi, name) + writer.streamingConnectContentTypes[streaming] = struct{}{} + writer.allContentTypes[streaming] = struct{}{} + } + if config.HandleGRPC { + writer.grpcContentTypes[grpcContentTypeDefault] = struct{}{} + writer.allContentTypes[grpcContentTypeDefault] = struct{}{} + for name := range config.Codecs { + ct := grpcContentTypeFromCodecName(false /* web */, name) + writer.grpcContentTypes[ct] = struct{}{} + writer.allContentTypes[ct] = struct{}{} + } + } + if config.HandleGRPCWeb { + writer.grpcWebContentTypes[grpcWebContentTypeDefault] = struct{}{} + writer.allContentTypes[grpcWebContentTypeDefault] = struct{}{} + for name := range config.Codecs { + ct := grpcContentTypeFromCodecName(true /* web */, name) + writer.grpcWebContentTypes[ct] = struct{}{} + writer.allContentTypes[ct] = struct{}{} + } + } + return writer +} + +// IsSupported checks whether a request is using one of the ErrorWriter's +// supported RPC protocols. +func (w *ErrorWriter) IsSupported(request *http.Request) bool { + ctype := request.Header.Get(headerContentType) + _, ok := w.allContentTypes[ctype] + return ok +} + +// Write an error, using the format appropriate for the RPC protocol in use. +// Callers should first use IsSupported to verify that the request is using one +// of the ErrorWriter's supported RPC protocols. +// +// Write does not read or close the request body. +func (w *ErrorWriter) Write(response http.ResponseWriter, request *http.Request, err error) error { + ctype := request.Header.Get(headerContentType) + if _, ok := w.unaryConnectContentTypes[ctype]; ok { + // Unary errors are always JSON. + response.Header().Set(headerContentType, connectUnaryContentTypeJSON) + return w.writeConnectUnary(response, err) + } + if _, ok := w.streamingConnectContentTypes[ctype]; ok { + response.Header().Set(headerContentType, ctype) + return w.writeConnectStreaming(response, err) + } + if _, ok := w.grpcContentTypes[ctype]; ok { + response.Header().Set(headerContentType, ctype) + return w.writeGRPC(response, err) + } + if _, ok := w.grpcWebContentTypes[ctype]; ok { + response.Header().Set(headerContentType, ctype) + return w.writeGRPCWeb(response, err) + } + return fmt.Errorf("unsupported Content-Type %q", ctype) +} + +func (w *ErrorWriter) writeConnectUnary(response http.ResponseWriter, err error) error { + if connectErr, ok := asError(err); ok { + mergeHeaders(response.Header(), connectErr.meta) + } + response.WriteHeader(connectCodeToHTTP(CodeOf(err))) + data, marshalErr := json.Marshal(newConnectWireError(err)) + if marshalErr != nil { + return fmt.Errorf("marshal error: %w", marshalErr) + } + _, writeErr := response.Write(data) + return writeErr +} + +func (w *ErrorWriter) writeConnectStreaming(response http.ResponseWriter, err error) error { + response.WriteHeader(http.StatusOK) + marshaler := &connectStreamingMarshaler{ + envelopeWriter: envelopeWriter{ + writer: response, + bufferPool: w.bufferPool, + }, + } + // MarshalEndStream returns *Error: check return value to avoid typed nils. + if marshalErr := marshaler.MarshalEndStream(err, make(http.Header)); marshalErr != nil { + return marshalErr + } + return nil +} + +func (w *ErrorWriter) writeGRPC(response http.ResponseWriter, err error) error { + trailers := make(http.Header, 2) // need space for at least code & message + grpcErrorToTrailer(w.bufferPool, trailers, w.protobuf, err) + // To make net/http reliably send trailers without a body, we must set the + // Trailers header rather than using http.TrailerPrefix. See + // https://github.com/golang/go/issues/54723. + keys := make([]string, 0, len(trailers)) + for k := range trailers { + keys = append(keys, k) + } + response.Header().Set("Trailer", strings.Join(keys, ",")) + response.WriteHeader(http.StatusOK) + mergeHeaders(response.Header(), trailers) + return nil +} + +func (w *ErrorWriter) writeGRPCWeb(response http.ResponseWriter, err error) error { + // This is a trailers-only response. To match the behavior of Envoy and + // protocol_grpc.go, put the trailers in the HTTP headers. + grpcErrorToTrailer(w.bufferPool, response.Header(), w.protobuf, err) + response.WriteHeader(http.StatusOK) + return nil +} diff --git a/error_writer_example_test.go b/error_writer_example_test.go new file mode 100644 index 00000000..1c26ab64 --- /dev/null +++ b/error_writer_example_test.go @@ -0,0 +1,69 @@ +// Copyright 2021-2022 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connect_test + +import ( + "errors" + "io" + "log" + "net/http" + + "github.com/bufbuild/connect-go" +) + +// NewHelloHandler is an example HTTP handler. In a real application, it might +// handle RPCs, requests for HTML, or anything else. +func NewHelloHandler() http.Handler { + return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { + io.WriteString(response, "Hello, world!") + }) +} + +// NewAuthenticatedHandler is an example of middleware that works with both RPC +// and non-RPC clients. +func NewAuthenticatedHandler(handler http.Handler) http.Handler { + errorWriter := connect.NewErrorWriter() + return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { + // Dummy authentication logic. + if request.Header.Get("Token") == "super-secret" { + handler.ServeHTTP(response, request) + return + } + defer request.Body.Close() + defer io.Copy(io.Discard, request.Body) + if errorWriter.IsSupported(request) { + // Send a protocol-appropriate error to RPC clients, so that they receive + // the right code, message, and any metadata or error details. + unauthenticated := connect.NewError(connect.CodeUnauthenticated, errors.New("invalid token")) + errorWriter.Write(response, request, unauthenticated) + } else { + // Send an error to non-RPC clients. + response.WriteHeader(http.StatusUnauthorized) + io.WriteString(response, "invalid token") + } + }) +} + +func ExampleErrorWriter() { + mux := http.NewServeMux() + mux.Handle("/", NewHelloHandler()) + srv := &http.Server{ + Addr: ":8080", + Handler: NewAuthenticatedHandler(mux), + } + if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Fatalln(err) + } +}