Skip to content

Commit

Permalink
Add helper to write RPC errors from HTTP middleware (#337)
Browse files Browse the repository at this point in the history
Add a protocol-aware helper type to write RPC errors from net/http middleware. 
The new ErrorWriter struct accepts HandlerOptions, so that it's fully aware of
the supported Codecs and RPC protocols, and it supports gRPC, gRPC-Web,
and Connect clients.

Co-authored-by: Akshay Shah <[email protected]>
  • Loading branch information
rhbuf and akshayjshah authored Aug 29, 2022
1 parent 98d4580 commit edbc6ba
Show file tree
Hide file tree
Showing 4 changed files with 299 additions and 7 deletions.
3 changes: 3 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ issues:
- linters: [varnamelen]
path: cmd/protoc-gen-connect-go
text: "parameter name 'g' is too short"
# Thorough error logging and timeout config make this example unreadably long.
- linters: [errcheck, gosec]
path: error_writer_example_test.go
# It should be crystal clear that Connect uses plain *http.Clients.
- linters: [revive, stylecheck]
path: client_example_test.go
Expand Down
66 changes: 59 additions & 7 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,12 @@ const errorMessage = "oh no"
// client doesn't set a header, and the server sets headers and trailers on the
// response.
const (
headerValue = "some header value"
trailerValue = "some trailer value"
clientHeader = "Connect-Client-Header"
handlerHeader = "Connect-Handler-Header"
handlerTrailer = "Connect-Handler-Trailer"
headerValue = "some header value"
trailerValue = "some trailer value"
clientHeader = "Connect-Client-Header"
handlerHeader = "Connect-Handler-Header"
handlerTrailer = "Connect-Handler-Trailer"
clientMiddlewareErrorHeader = "Connect-Trigger-HTTP-Error"
)

func TestServer(t *testing.T) {
Expand Down Expand Up @@ -289,6 +290,19 @@ func TestServer(t *testing.T) {
})
}
testErrors := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper
assertIsHTTPMiddlewareError := func(tb testing.TB, err error) {
tb.Helper()
assert.NotNil(tb, err)
var connectErr *connect.Error
assert.True(tb, errors.As(err, &connectErr))
expect := newHTTPMiddlewareError()
assert.Equal(tb, connectErr.Code(), expect.Code())
assert.Equal(tb, connectErr.Message(), expect.Message())
for k, v := range expect.Meta() {
assert.Equal(tb, connectErr.Meta().Values(k), v)
}
assert.Equal(tb, len(connectErr.Details()), len(expect.Details()))
}
t.Run("errors", func(t *testing.T) {
request := connect.NewRequest(&pingv1.FailRequest{
Code: int32(connect.CodeResourceExhausted),
Expand All @@ -307,6 +321,20 @@ func TestServer(t *testing.T) {
assert.Equal(t, connectErr.Meta().Values(handlerHeader), []string{headerValue})
assert.Equal(t, connectErr.Meta().Values(handlerTrailer), []string{trailerValue})
})
t.Run("middleware_errors_unary", func(t *testing.T) {
request := connect.NewRequest(&pingv1.PingRequest{})
request.Header().Set(clientMiddlewareErrorHeader, headerValue)
_, err := client.Ping(context.Background(), request)
assertIsHTTPMiddlewareError(t, err)
})
t.Run("middleware_errors_streaming", func(t *testing.T) {
request := connect.NewRequest(&pingv1.CountUpRequest{Number: 10})
request.Header().Set(clientMiddlewareErrorHeader, headerValue)
stream, err := client.CountUp(context.Background(), request)
assert.Nil(t, err)
assert.False(t, stream.Receive())
assertIsHTTPMiddlewareError(t, stream.Err())
})
}
testMatrix := func(t *testing.T, server *httptest.Server, bidi bool) { //nolint:thelper
run := func(t *testing.T, opts ...connect.ClientOption) {
Expand Down Expand Up @@ -368,9 +396,27 @@ func TestServer(t *testing.T) {
}

mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(
pingRoute, pingHandler := pingv1connect.NewPingServiceHandler(
pingServer{checkMetadata: true},
))
)
errorWriter := connect.NewErrorWriter()
// Add some net/http middleware to the ping service so we can also exercise ErrorWriter.
mux.Handle(pingRoute, http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
if request.Header.Get(clientMiddlewareErrorHeader) != "" {
defer request.Body.Close()
if _, err := io.Copy(io.Discard, request.Body); err != nil {
t.Errorf("drain request body: %v", err)
}
if !errorWriter.IsSupported(request) {
t.Errorf("ErrorWriter doesn't support Content-Type %q", request.Header.Get("Content-Type"))
}
if err := errorWriter.Write(response, request, newHTTPMiddlewareError()); err != nil {
t.Errorf("send RPC error from HTTP middleware: %v", err)
}
return
}
pingHandler.ServeHTTP(response, request)
}))

t.Run("http1", func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -1450,3 +1496,9 @@ func (l *trimTrailerWriter) removeTrailers() {
}
}
}

func newHTTPMiddlewareError() *connect.Error {
err := connect.NewError(connect.CodeResourceExhausted, errors.New("error from HTTP middleware"))
err.Meta().Set("Middleware-Foo", "bar")
return err
}
168 changes: 168 additions & 0 deletions error_writer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
// Copyright 2021-2022 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package connect

import (
"encoding/json"
"fmt"
"net/http"
"strings"
)

// An ErrorWriter writes errors to an [http.ResponseWriter] in the format
// expected by an RPC client. This is especially useful in server-side net/http
// middleware, where you may wish to handle requests from RPC and non-RPC
// clients with the same code.
//
// ErrorWriters are safe to use concurrently.
type ErrorWriter struct {
bufferPool *bufferPool
protobuf Codec
allContentTypes map[string]struct{}
grpcContentTypes map[string]struct{}
grpcWebContentTypes map[string]struct{}
unaryConnectContentTypes map[string]struct{}
streamingConnectContentTypes map[string]struct{}
}

// NewErrorWriter constructs an ErrorWriter. To properly recognize supported
// RPC Content-Types in net/http middleware, you must pass the same
// HandlerOptions to NewErrorWriter and any wrapped Connect handlers.
func NewErrorWriter(opts ...HandlerOption) *ErrorWriter {
config := newHandlerConfig("", opts)
writer := &ErrorWriter{
bufferPool: config.BufferPool,
protobuf: newReadOnlyCodecs(config.Codecs).Protobuf(),
allContentTypes: make(map[string]struct{}),
grpcContentTypes: make(map[string]struct{}),
grpcWebContentTypes: make(map[string]struct{}),
unaryConnectContentTypes: make(map[string]struct{}),
streamingConnectContentTypes: make(map[string]struct{}),
}
for name := range config.Codecs {
unary := connectContentTypeFromCodecName(StreamTypeUnary, name)
writer.allContentTypes[unary] = struct{}{}
writer.unaryConnectContentTypes[unary] = struct{}{}
streaming := connectContentTypeFromCodecName(StreamTypeBidi, name)
writer.streamingConnectContentTypes[streaming] = struct{}{}
writer.allContentTypes[streaming] = struct{}{}
}
if config.HandleGRPC {
writer.grpcContentTypes[grpcContentTypeDefault] = struct{}{}
writer.allContentTypes[grpcContentTypeDefault] = struct{}{}
for name := range config.Codecs {
ct := grpcContentTypeFromCodecName(false /* web */, name)
writer.grpcContentTypes[ct] = struct{}{}
writer.allContentTypes[ct] = struct{}{}
}
}
if config.HandleGRPCWeb {
writer.grpcWebContentTypes[grpcWebContentTypeDefault] = struct{}{}
writer.allContentTypes[grpcWebContentTypeDefault] = struct{}{}
for name := range config.Codecs {
ct := grpcContentTypeFromCodecName(true /* web */, name)
writer.grpcWebContentTypes[ct] = struct{}{}
writer.allContentTypes[ct] = struct{}{}
}
}
return writer
}

// IsSupported checks whether a request is using one of the ErrorWriter's
// supported RPC protocols.
func (w *ErrorWriter) IsSupported(request *http.Request) bool {
ctype := request.Header.Get(headerContentType)
_, ok := w.allContentTypes[ctype]
return ok
}

// Write an error, using the format appropriate for the RPC protocol in use.
// Callers should first use IsSupported to verify that the request is using one
// of the ErrorWriter's supported RPC protocols.
//
// Write does not read or close the request body.
func (w *ErrorWriter) Write(response http.ResponseWriter, request *http.Request, err error) error {
ctype := request.Header.Get(headerContentType)
if _, ok := w.unaryConnectContentTypes[ctype]; ok {
// Unary errors are always JSON.
response.Header().Set(headerContentType, connectUnaryContentTypeJSON)
return w.writeConnectUnary(response, err)
}
if _, ok := w.streamingConnectContentTypes[ctype]; ok {
response.Header().Set(headerContentType, ctype)
return w.writeConnectStreaming(response, err)
}
if _, ok := w.grpcContentTypes[ctype]; ok {
response.Header().Set(headerContentType, ctype)
return w.writeGRPC(response, err)
}
if _, ok := w.grpcWebContentTypes[ctype]; ok {
response.Header().Set(headerContentType, ctype)
return w.writeGRPCWeb(response, err)
}
return fmt.Errorf("unsupported Content-Type %q", ctype)
}

func (w *ErrorWriter) writeConnectUnary(response http.ResponseWriter, err error) error {
if connectErr, ok := asError(err); ok {
mergeHeaders(response.Header(), connectErr.meta)
}
response.WriteHeader(connectCodeToHTTP(CodeOf(err)))
data, marshalErr := json.Marshal(newConnectWireError(err))
if marshalErr != nil {
return fmt.Errorf("marshal error: %w", marshalErr)
}
_, writeErr := response.Write(data)
return writeErr
}

func (w *ErrorWriter) writeConnectStreaming(response http.ResponseWriter, err error) error {
response.WriteHeader(http.StatusOK)
marshaler := &connectStreamingMarshaler{
envelopeWriter: envelopeWriter{
writer: response,
bufferPool: w.bufferPool,
},
}
// MarshalEndStream returns *Error: check return value to avoid typed nils.
if marshalErr := marshaler.MarshalEndStream(err, make(http.Header)); marshalErr != nil {
return marshalErr
}
return nil
}

func (w *ErrorWriter) writeGRPC(response http.ResponseWriter, err error) error {
trailers := make(http.Header, 2) // need space for at least code & message
grpcErrorToTrailer(w.bufferPool, trailers, w.protobuf, err)
// To make net/http reliably send trailers without a body, we must set the
// Trailers header rather than using http.TrailerPrefix. See
// https://github.com/golang/go/issues/54723.
keys := make([]string, 0, len(trailers))
for k := range trailers {
keys = append(keys, k)
}
response.Header().Set("Trailer", strings.Join(keys, ","))
response.WriteHeader(http.StatusOK)
mergeHeaders(response.Header(), trailers)
return nil
}

func (w *ErrorWriter) writeGRPCWeb(response http.ResponseWriter, err error) error {
// This is a trailers-only response. To match the behavior of Envoy and
// protocol_grpc.go, put the trailers in the HTTP headers.
grpcErrorToTrailer(w.bufferPool, response.Header(), w.protobuf, err)
response.WriteHeader(http.StatusOK)
return nil
}
69 changes: 69 additions & 0 deletions error_writer_example_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright 2021-2022 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package connect_test

import (
"errors"
"io"
"log"
"net/http"

"github.com/bufbuild/connect-go"
)

// NewHelloHandler is an example HTTP handler. In a real application, it might
// handle RPCs, requests for HTML, or anything else.
func NewHelloHandler() http.Handler {
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
io.WriteString(response, "Hello, world!")
})
}

// NewAuthenticatedHandler is an example of middleware that works with both RPC
// and non-RPC clients.
func NewAuthenticatedHandler(handler http.Handler) http.Handler {
errorWriter := connect.NewErrorWriter()
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
// Dummy authentication logic.
if request.Header.Get("Token") == "super-secret" {
handler.ServeHTTP(response, request)
return
}
defer request.Body.Close()
defer io.Copy(io.Discard, request.Body)
if errorWriter.IsSupported(request) {
// Send a protocol-appropriate error to RPC clients, so that they receive
// the right code, message, and any metadata or error details.
unauthenticated := connect.NewError(connect.CodeUnauthenticated, errors.New("invalid token"))
errorWriter.Write(response, request, unauthenticated)
} else {
// Send an error to non-RPC clients.
response.WriteHeader(http.StatusUnauthorized)
io.WriteString(response, "invalid token")
}
})
}

func ExampleErrorWriter() {
mux := http.NewServeMux()
mux.Handle("/", NewHelloHandler())
srv := &http.Server{
Addr: ":8080",
Handler: NewAuthenticatedHandler(mux),
}
if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatalln(err)
}
}

0 comments on commit edbc6ba

Please sign in to comment.