From 7b3b3447f754588d2d5cadfd77f4695097dd7b52 Mon Sep 17 00:00:00 2001 From: Edward McFarlane <3036610+emcfarlane@users.noreply.github.com> Date: Tue, 19 Mar 2024 15:25:46 -0400 Subject: [PATCH] Restrict metadata headers in error propagation (#711) This PR addresses issues when propagating errors from a client back to a handler. On the client side connect errors will contain all response headers: transport (`Content-Type`, `Content-Length`, etc), protocol and application headers. These could break the transport when trying to re-encode the error or leak sensitive information between services. For any wire errors (errors decoded from a client response) we now disable meta propagation. For other errors we now also restrict the headers propagated. --- connect_ext_test.go | 132 ++++++++++++++++++++++++++++++++++++++++++++ error.go | 9 +++ error_writer.go | 4 +- header.go | 40 ++++++++++++++ protocol.go | 1 + protocol_connect.go | 13 +++-- protocol_grpc.go | 4 +- 7 files changed, 193 insertions(+), 10 deletions(-) diff --git a/connect_ext_test.go b/connect_ext_test.go index 5f2076d2..00acd1ff 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -43,6 +43,7 @@ import ( "connectrpc.com/connect/internal/memhttp/memhttptest" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/types/known/wrapperspb" ) const errorMessage = "oh no" @@ -542,6 +543,137 @@ func TestConcurrentStreams(t *testing.T) { done.Wait() } +func TestErrorHeaderPropagation(t *testing.T) { + t.Parallel() + newError := func(testname string, isWire bool) *connect.Error { + err := connect.NewError(connect.CodeInvalidArgument, errors.New(testname)) + if isWire { + err = connect.NewWireError(connect.CodeInvalidArgument, errors.New(testname)) + } + msgDetail := &wrapperspb.StringValue{Value: "server details"} + errDetail, derr := connect.NewErrorDetail(msgDetail) + if assert.Nil(t, derr) { + err.AddDetail(errDetail) + } + err.Meta().Set("Content-Length", "1337") + err.Meta().Set("Content-Type", "application/xml") + err.Meta().Set("Accept-Encoding", "bogus") + err.Meta().Set("Date", "Thu, 01 Jan 1970 00:00:00 GMT") + err.Meta().Set("Grpc-Status", "0") + // Set custom headers. + err.Meta().Set("X-Test", testname) + err.Meta()["x-test-case"] = []string{testname} + return err + } + pingServer := &pluggablePingServer{ + ping: func(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { + return nil, newError(request.Header().Get("X-Test"), request.Header().Get("X-Test-Is-Wire") == "true") + }, + cumSum: func(ctx context.Context, stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse]) error { + return newError(stream.RequestHeader().Get("X-Test"), stream.RequestHeader().Get("X-Test-Is-Wire") == "true") + }, + } + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) + server := memhttptest.NewServer(t, mux) + + assertError := func(t *testing.T, err error, allowCustomHeaders bool) { + t.Helper() + var connectErr *connect.Error + if !assert.True(t, errors.As(err, &connectErr)) { + return + } + assert.Equal(t, connectErr.Code(), connect.CodeInvalidArgument) + assert.Equal(t, connectErr.Message(), t.Name()) + details := connectErr.Details() + if assert.Equal(t, len(details), 1) { + detailMsg, err := details[0].Value() + if !assert.Nil(t, err) { + return + } + serverDetails, ok := detailMsg.(*wrapperspb.StringValue) + if !assert.True(t, ok) { + return + } + assert.Equal(t, serverDetails.Value, "server details") + } + meta := connectErr.Meta() + assert.NotEqual(t, meta.Values("Content-Length"), []string{"1337"}) + assert.NotEqual(t, meta.Values("Accept-Encoding"), []string{"bogus"}) + assert.NotEqual(t, meta.Values("Content-Type"), []string{"application/xml"}) + assert.NotEqual(t, meta.Values("Content-Length"), []string{"1337"}) + assert.NotEqual(t, meta.Values("Date"), []string{"Thu, 01 Jan 1970 00:00:00 GMT"}) + if allowCustomHeaders { + assert.Equal(t, meta.Values("x-test-case"), []string{t.Name()}) + assert.Equal(t, meta.Values("X-Test"), []string{t.Name()}) + } else { + assert.Equal(t, meta.Values("x-test-case"), []string(nil)) + assert.Equal(t, meta.Values("X-Test"), []string(nil)) + } + } + testServices := func(t *testing.T, client pingv1connect.PingServiceClient) { + t.Helper() + t.Run("unary", func(t *testing.T) { + request := connect.NewRequest(&pingv1.PingRequest{}) + request.Header().Set("X-Test", t.Name()) + _, err := client.Ping(context.Background(), request) + if !assert.NotNil(t, err) { + return + } + assertError(t, err, true /* allowCustomHeaders */) + t.Run("wire", func(t *testing.T) { + request := connect.NewRequest(&pingv1.PingRequest{}) + request.Header().Set("X-Test", t.Name()) + request.Header().Set("X-Test-Is-Wire", "true") + _, err := client.Ping(context.Background(), request) + if !assert.NotNil(t, err) { + return + } + assertError(t, err, false /* allowCustomHeaders */) + }) + }) + t.Run("bidi", func(t *testing.T) { + stream := client.CumSum(context.Background()) + stream.RequestHeader().Set("X-Test", t.Name()) + if err := stream.Send(nil); err != nil { + t.Fatal(err) + } + _, err := stream.Receive() + if !assert.NotNil(t, err) { + return + } + assertError(t, err, true /* allowCustomHeaders */) + t.Run("wire", func(t *testing.T) { + stream := client.CumSum(context.Background()) + stream.RequestHeader().Set("X-Test", t.Name()) + stream.RequestHeader().Set("X-Test-Is-Wire", "true") + if err := stream.Send(nil); err != nil { + t.Fatal(err) + } + _, err := stream.Receive() + if !assert.NotNil(t, err) { + return + } + }) + }) + } + t.Run("connect", func(t *testing.T) { + t.Parallel() + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + testServices(t, client) + }) + t.Run("grpc", func(t *testing.T) { + t.Parallel() + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + testServices(t, client) + }) + t.Run("grpc-web", func(t *testing.T) { + t.Parallel() + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + testServices(t, client) + }) +} + func TestHeaderBasic(t *testing.T) { t.Parallel() const ( diff --git a/error.go b/error.go index 26544d95..f461d1bb 100644 --- a/error.go +++ b/error.go @@ -158,6 +158,10 @@ func NewWireError(c Code, underlying error) *Error { // Clients may find this useful when deciding how to propagate errors. For // example, an RPC-to-HTTP proxy might expose a server-sent CodeUnknown as an // HTTP 500 but a client-synthesized CodeUnknown as a 503. +// +// Handlers will strip [Error.Meta] headers propagated from wire errors to avoid +// leaking response headers. To propagate headers recreate the error as a +// non-wire error. func IsWireError(err error) bool { se := new(Error) if !errors.As(err, &se) { @@ -229,6 +233,11 @@ func (e *Error) AddDetail(d *ErrorDetail) { // or a block of in-body metadata, depending on the protocol in use and whether // or not the handler has already written messages to the stream. // +// Protocol-specific headers and trailers may be removed to avoid breaking +// protocol semantics. For example, Content-Length and Content-Type headers +// won't be propagated. See the documentation for each protocol for more +// datails. +// // When clients receive errors, the metadata contains the union of the HTTP // headers and the protocol-specific trailers (either HTTP trailers or in-body // metadata). diff --git a/error_writer.go b/error_writer.go index 466c3b8e..58ce3c42 100644 --- a/error_writer.go +++ b/error_writer.go @@ -127,8 +127,8 @@ func (w *ErrorWriter) Write(response http.ResponseWriter, request *http.Request, } func (w *ErrorWriter) writeConnectUnary(response http.ResponseWriter, err error) error { - if connectErr, ok := asError(err); ok { - mergeHeaders(response.Header(), connectErr.meta) + if connectErr, ok := asError(err); ok && !connectErr.wireErr { + mergeMetadataHeaders(response.Header(), connectErr.meta) } response.WriteHeader(connectCodeToHTTP(CodeOf(err))) data, marshalErr := json.Marshal(newConnectWireError(err)) diff --git a/header.go b/header.go index 2e827309..f3c7cacd 100644 --- a/header.go +++ b/header.go @@ -57,6 +57,46 @@ func mergeHeaders(into, from http.Header) { } } +// mergeMetdataHeaders merges the metadata headers from the "from" header into +// the "into" header. It skips over non metadata headers that should not be +// propagated from the server to the client. +func mergeMetadataHeaders(into, from http.Header) { + for key, vals := range from { + if len(vals) == 0 { + // For response trailers, net/http will pre-populate entries + // with nil values based on the "Trailer" header. But if there + // are no actual values for those keys, we skip them. + continue + } + switch http.CanonicalHeaderKey(key) { + case headerContentType, + headerContentLength, + headerContentEncoding, + headerHost, + headerUserAgent, + headerTrailer, + headerDate: + // HTTP headers. + case connectUnaryHeaderAcceptCompression, + connectUnaryTrailerPrefix, + connectStreamingHeaderCompression, + connectStreamingHeaderAcceptCompression, + connectHeaderTimeout, + connectHeaderProtocolVersion: + // Connect headers. + case grpcHeaderCompression, + grpcHeaderAcceptCompression, + grpcHeaderTimeout, + grpcHeaderStatus, + grpcHeaderMessage, + grpcHeaderDetails: + // gRPC headers. + default: + into[key] = append(into[key], vals...) + } + } +} + // getHeaderCanonical is a shortcut for Header.Get() which // bypasses the CanonicalMIMEHeaderKey operation when we // know the key is already in canonical form. diff --git a/protocol.go b/protocol.go index eb1984f0..9add614c 100644 --- a/protocol.go +++ b/protocol.go @@ -41,6 +41,7 @@ const ( headerHost = "Host" headerUserAgent = "User-Agent" headerTrailer = "Trailer" + headerDate = "Date" discardLimit = 1024 * 1024 * 4 // 4MiB ) diff --git a/protocol_connect.go b/protocol_connect.go index b53b52eb..d478d634 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -534,10 +534,12 @@ func (cc *connectUnaryClientConn) validateResponse(response *http.Response) *Err cc.compressionPools.CommaSeparatedNames(), ) } + cc.unmarshaler.compressionPool = cc.compressionPools.Get(compression) if response.StatusCode != http.StatusOK { unmarshaler := connectUnaryUnmarshaler{ + ctx: cc.unmarshaler.ctx, reader: response.Body, - compressionPool: cc.compressionPools.Get(compression), + compressionPool: cc.unmarshaler.compressionPool, bufferPool: cc.bufferPool, } var wireErr connectWireError @@ -559,7 +561,6 @@ func (cc *connectUnaryClientConn) validateResponse(response *http.Response) *Err mergeHeaders(serverErr.meta, cc.responseTrailer) return serverErr } - cc.unmarshaler.compressionPool = cc.compressionPools.Get(compression) return nil } @@ -765,8 +766,8 @@ func (hc *connectUnaryHandlerConn) writeResponseHeader(err error) { header[headerVary] = append(header[headerVary], connectUnaryHeaderAcceptCompression) } if err != nil { - if connectErr, ok := asError(err); ok { - mergeHeaders(header, connectErr.meta) + if connectErr, ok := asError(err); ok && !connectErr.wireErr { + mergeMetadataHeaders(header, connectErr.meta) } } for k, v := range hc.responseTrailer { @@ -850,8 +851,8 @@ func (m *connectStreamingMarshaler) MarshalEndStream(err error, trailer http.Hea end := &connectEndStreamMessage{Trailer: trailer} if err != nil { end.Error = newConnectWireError(err) - if connectErr, ok := asError(err); ok { - mergeHeaders(end.Trailer, connectErr.meta) + if connectErr, ok := asError(err); ok && !connectErr.wireErr { + mergeMetadataHeaders(end.Trailer, connectErr.meta) } } data, marshalErr := json.Marshal(end) diff --git a/protocol_grpc.go b/protocol_grpc.go index 5a58cf42..97f9c81e 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -859,8 +859,8 @@ func grpcErrorToTrailer(trailer http.Header, protobuf Codec, err error) { ) return } - if connectErr, ok := asError(err); ok { - mergeHeaders(trailer, connectErr.meta) + if connectErr, ok := asError(err); ok && !connectErr.wireErr { + mergeMetadataHeaders(trailer, connectErr.meta) } setHeaderCanonical(trailer, grpcHeaderStatus, code) setHeaderCanonical(trailer, grpcHeaderMessage, grpcPercentEncode(status.GetMessage()))