From aea011d6296ee7a23de87f30b6cd3facbf967c7d Mon Sep 17 00:00:00 2001 From: Josh Humphries <2035234+jhump@users.noreply.github.com> Date: Wed, 7 Feb 2024 12:18:32 -0500 Subject: [PATCH 1/2] check response content type --- connect_ext_test.go | 4 + duplex_http_call.go | 5 + error.go | 2 +- protocol_connect.go | 91 ++++++++++++++-- protocol_connect_test.go | 220 +++++++++++++++++++++++++++++++++++++++ protocol_grpc.go | 31 +++++- protocol_grpc_test.go | 123 ++++++++++++++++++++++ 7 files changed, 468 insertions(+), 8 deletions(-) diff --git a/connect_ext_test.go b/connect_ext_test.go index f9e3ce61..6d80302d 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -810,6 +810,7 @@ func TestUnavailableIfHostInvalid(t *testing.T) { func TestBidiRequiresHTTP2(t *testing.T) { t.Parallel() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/connect+proto") _, err := io.WriteString(w, "hello world") assert.Nil(t, err) }) @@ -841,6 +842,7 @@ func TestCompressMinBytesClient(t *testing.T) { tb.Helper() mux := http.NewServeMux() mux.Handle("/", http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + writer.Header().Set("Content-Type", "application/proto") assert.Equal(tb, request.Header.Get("Content-Encoding"), expect) })) server := memhttptest.NewServer(t, mux) @@ -2231,6 +2233,7 @@ func TestStreamUnexpectedEOF(t *testing.T) { name: "connect_excess_eof", options: []connect.ClientOption{connect.WithProtoJSON()}, handler: func(responseWriter http.ResponseWriter, _ *http.Request) { + responseWriter.Header().Set("Content-Type", "application/connect+json") _, err := responseWriter.Write(head[:]) assert.Nil(t, err) _, err = responseWriter.Write(payload) @@ -2252,6 +2255,7 @@ func TestStreamUnexpectedEOF(t *testing.T) { name: "grpc-web_excess_eof", options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPCWeb()}, handler: func(responseWriter http.ResponseWriter, _ *http.Request) { + responseWriter.Header().Set("Content-Type", "application/grpc-web+json") _, err := responseWriter.Write(head[:]) assert.Nil(t, err) _, err = responseWriter.Write(payload) diff --git a/duplex_http_call.go b/duplex_http_call.go index 38428be9..80f5f2e4 100644 --- a/duplex_http_call.go +++ b/duplex_http_call.go @@ -203,6 +203,11 @@ func (d *duplexHTTPCall) URL() *url.URL { return d.request.URL } +// Method returns the HTTP method for the request (GET or POST). +func (d *duplexHTTPCall) Method() string { + return d.request.Method +} + // SetMethod changes the method of the request before it is sent. func (d *duplexHTTPCall) SetMethod(method string) { d.request.Method = method diff --git a/error.go b/error.go index 4c57aade..2e85ebde 100644 --- a/error.go +++ b/error.go @@ -133,7 +133,7 @@ func NewError(c Code, underlying error) *Error { // This is useful for clients trying to propagate partial failures from // streaming RPCs. Often, these RPCs include error information in their // response messages (for example, [gRPC server reflection] and -// OpenTelemtetry's [OTLP]). Clients propagating these errors up the stack +// OpenTelemetry's [OTLP]). Clients propagating these errors up the stack // should use NewWireError to clarify that the error code, message, and details // (if any) were explicitly sent by the server rather than inferred from a // lower-level networking error or timeout. diff --git a/protocol_connect.go b/protocol_connect.go index 9aaa71ca..ec03ecc9 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -511,6 +511,21 @@ func (cc *connectUnaryClientConn) validateResponse(response *http.Response) *Err } cc.responseTrailer[strings.TrimPrefix(k, connectUnaryTrailerPrefix)] = v } + err := connectValidateUnaryResponseContentType( + cc.marshaler.codec.Name(), + cc.duplexCall.Method(), + response.StatusCode, + response.Status, + getHeaderCanonical(response.Header, headerContentType), + ) + if err != nil { + if IsNotModifiedError(err) { + // Allow access to response headers for this kind of error. + // RFC 9110 doesn't allow trailers on 304s, so we only need to include headers. + err.meta = cc.responseHeader.Clone() + } + return err + } compression := getHeaderCanonical(response.Header, connectUnaryHeaderCompression) if compression != "" && compression != compressionIdentity && @@ -522,12 +537,7 @@ func (cc *connectUnaryClientConn) validateResponse(response *http.Response) *Err cc.compressionPools.CommaSeparatedNames(), ) } - if response.StatusCode == http.StatusNotModified && cc.Spec().IdempotencyLevel == IdempotencyNoSideEffects { - serverErr := NewWireError(CodeUnknown, errNotModifiedClient) - // RFC 9110 doesn't allow trailers on 304s, so we only need to include headers. - serverErr.meta = cc.responseHeader.Clone() - return serverErr - } else if response.StatusCode != http.StatusOK { + if response.StatusCode != http.StatusOK { unmarshaler := connectUnaryUnmarshaler{ reader: response.Body, compressionPool: cc.compressionPools.Get(compression), @@ -643,6 +653,14 @@ func (cc *connectStreamingClientConn) validateResponse(response *http.Response) if response.StatusCode != http.StatusOK { return errorf(connectHTTPToCode(response.StatusCode), "HTTP status %v", response.Status) } + err := connectValidateStreamResponseContentType( + cc.codec.Name(), + cc.spec.StreamType, + getHeaderCanonical(response.Header, headerContentType), + ) + if err != nil { + return err + } compression := getHeaderCanonical(response.Header, connectStreamingHeaderCompression) if compression != "" && compression != compressionIdentity && @@ -1324,3 +1342,64 @@ func queryValueReader(data string, base64Encoded bool) io.Reader { } return strings.NewReader(data) } + +func connectValidateUnaryResponseContentType( + requestCodecName string, + httpMethod string, + statusCode int, + statusMsg string, + responseContentType string, +) *Error { + if statusCode != http.StatusOK { + if statusCode == http.StatusNotModified && httpMethod == http.MethodGet { + return NewWireError(CodeUnknown, errNotModifiedClient) + } + // Error responses must be JSON-encoded. + if responseContentType == connectUnaryContentTypePrefix+codecNameJSON || + responseContentType == connectUnaryContentTypePrefix+codecNameJSONCharsetUTF8 { + return nil + } + return NewError( + connectHTTPToCode(statusCode), + errors.New(statusMsg), + ) + } + // Normal responses must have valid content-type that indicates same codec as the request. + responseCodecName := connectCodecFromContentType( + StreamTypeUnary, + responseContentType, + ) + if responseCodecName == requestCodecName { + return nil + } + // HACK: We likely want a better way to handle the optional "charset" parameter + // for application/json, instead of hard-coding. But this suffices for now. + if (responseCodecName == codecNameJSON && requestCodecName == codecNameJSONCharsetUTF8) || + (responseCodecName == codecNameJSONCharsetUTF8 && requestCodecName == codecNameJSON) { + // Both are JSON + return nil + } + return errorf( + CodeInternal, + "invalid content-type: %q; expecting %q", + responseContentType, + connectUnaryContentTypePrefix+requestCodecName, + ) +} + +func connectValidateStreamResponseContentType(requestCodecName string, streamType StreamType, responseContentType string) *Error { + // Responses must have valid content-type that indicates same codec as the request. + responseCodecName := connectCodecFromContentType( + streamType, + responseContentType, + ) + if responseCodecName != requestCodecName { + return errorf( + CodeInternal, + "invalid content-type: %q; expecting %q", + responseContentType, + connectStreamingContentTypePrefix+requestCodecName, + ) + } + return nil +} diff --git a/protocol_connect_test.go b/protocol_connect_test.go index 9b3571cf..cc53bcfd 100644 --- a/protocol_connect_test.go +++ b/protocol_connect_test.go @@ -17,6 +17,7 @@ package connect import ( "bytes" "encoding/json" + "fmt" "net/http" "strings" "testing" @@ -93,3 +94,222 @@ func TestConnectEndOfResponseCanonicalTrailers(t *testing.T) { assert.Equal(t, unmarshaler.Trailer().Values("Mixed-Canonical"), []string{"b", "b"}) assert.Equal(t, unmarshaler.Trailer().Values("Canonical-Header"), []string{"c"}) } + +func TestConnectValidateUnaryResponseContentType(t *testing.T) { + t.Parallel() + testCases := []struct { + codecName string + get bool + statusCode int + responseContentType string + expectCode Code + expectBadContentType bool + expectNotModified bool + }{ + // Allowed content-types for OK responses. + { + codecName: codecNameProto, + statusCode: http.StatusOK, + responseContentType: "application/proto", + }, + { + codecName: codecNameJSON, + statusCode: http.StatusOK, + responseContentType: "application/json", + }, + { + codecName: codecNameJSON, + statusCode: http.StatusOK, + responseContentType: "application/json; charset=utf-8", + }, + { + codecName: codecNameJSONCharsetUTF8, + statusCode: http.StatusOK, + responseContentType: "application/json", + }, + { + codecName: codecNameJSONCharsetUTF8, + statusCode: http.StatusOK, + responseContentType: "application/json; charset=utf-8", + }, + // Allowed content-types for error responses. + { + codecName: codecNameProto, + statusCode: http.StatusNotFound, + responseContentType: "application/json", + }, + { + codecName: codecNameProto, + statusCode: http.StatusBadRequest, + responseContentType: "application/json; charset=utf-8", + }, + { + codecName: codecNameJSON, + statusCode: http.StatusInternalServerError, + responseContentType: "application/json", + }, + { + codecName: codecNameJSON, + statusCode: http.StatusPreconditionFailed, + responseContentType: "application/json; charset=utf-8", + }, + // 304 Not Modified for GET request gets a special error, regardless of content-type + { + codecName: codecNameProto, + get: true, + statusCode: http.StatusNotModified, + responseContentType: "application/json", + expectCode: CodeUnknown, + expectNotModified: true, + }, + { + codecName: codecNameJSON, + get: true, + statusCode: http.StatusNotModified, + responseContentType: "application/json", + expectCode: CodeUnknown, + expectNotModified: true, + }, + // OK status, invalid content-type + { + codecName: codecNameProto, + statusCode: http.StatusOK, + responseContentType: "application/proto; charset=utf-8", + expectCode: CodeInternal, + expectBadContentType: true, + }, + { + codecName: codecNameProto, + statusCode: http.StatusOK, + responseContentType: "application/json", + expectCode: CodeInternal, + expectBadContentType: true, + }, + { + codecName: codecNameJSON, + statusCode: http.StatusOK, + responseContentType: "application/proto", + expectCode: CodeInternal, + expectBadContentType: true, + }, + { + codecName: codecNameJSON, + statusCode: http.StatusOK, + responseContentType: "some/garbage", + expectCode: CodeInternal, + expectBadContentType: true, + }, + // Error status, invalid content-type, returns code based on HTTP status code + { + codecName: codecNameProto, + statusCode: http.StatusNotFound, + responseContentType: "application/proto", + expectCode: connectHTTPToCode(http.StatusNotFound), + }, + { + codecName: codecNameJSON, + statusCode: http.StatusBadRequest, + responseContentType: "some/garbage", + expectCode: connectHTTPToCode(http.StatusBadRequest), + }, + { + codecName: codecNameJSON, + statusCode: http.StatusTooManyRequests, + responseContentType: "some/garbage", + expectCode: connectHTTPToCode(http.StatusTooManyRequests), + }, + } + for _, testCase := range testCases { + testCase := testCase + httpMethod := http.MethodPost + if testCase.get { + httpMethod = http.MethodGet + } + testCaseName := fmt.Sprintf("%s_%s->%d_%s", httpMethod, testCase.codecName, testCase.statusCode, testCase.responseContentType) + t.Run(testCaseName, func(t *testing.T) { + t.Parallel() + err := connectValidateUnaryResponseContentType( + testCase.codecName, + httpMethod, + testCase.statusCode, + http.StatusText(testCase.statusCode), + testCase.responseContentType, + ) + if testCase.expectCode == 0 { + assert.Nil(t, err) + } else if assert.NotNil(t, err) { + assert.Equal(t, CodeOf(err), testCase.expectCode) + if testCase.expectNotModified { + assert.ErrorIs(t, err, errNotModified) + } else if testCase.expectBadContentType { + assert.True(t, strings.Contains(err.Message(), fmt.Sprintf("invalid content-type: %q; expecting", testCase.responseContentType))) + } else { + assert.Equal(t, err.Message(), http.StatusText(testCase.statusCode)) + } + } + }) + } +} + +func TestConnectValidateStreamResponseContentType(t *testing.T) { + t.Parallel() + testCases := []struct { + codecName string + responseContentType string + expectErr bool + }{ + // Allowed content-types + { + codecName: codecNameProto, + responseContentType: "application/connect+proto", + }, + { + codecName: codecNameJSON, + responseContentType: "application/connect+json", + }, + // Disallowed content-types + { + codecName: codecNameProto, + responseContentType: "application/proto", + expectErr: true, + }, + { + codecName: codecNameJSON, + responseContentType: "application/json", + expectErr: true, + }, + { + codecName: codecNameJSON, + responseContentType: "application/json; charset=utf-8", + expectErr: true, + }, + { + codecName: codecNameJSON, + responseContentType: "application/connect+json; charset=utf-8", + expectErr: true, + }, + { + codecName: codecNameProto, + responseContentType: "some/garbage", + expectErr: true, + }, + } + for _, testCase := range testCases { + testCase := testCase + testCaseName := fmt.Sprintf("%s->%s", testCase.codecName, testCase.responseContentType) + t.Run(testCaseName, func(t *testing.T) { + t.Parallel() + err := connectValidateStreamResponseContentType( + testCase.codecName, + StreamTypeServer, + testCase.responseContentType, + ) + if !testCase.expectErr { + assert.Nil(t, err) + } else if assert.NotNil(t, err) { + assert.Equal(t, CodeOf(err), CodeInternal) + assert.True(t, strings.Contains(err.Message(), fmt.Sprintf("invalid content-type: %q; expecting", testCase.responseContentType))) + } + }) + } +} diff --git a/protocol_grpc.go b/protocol_grpc.go index 6dc9012f..5b61a856 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -135,7 +135,7 @@ func (*grpcHandler) SetTimeout(request *http.Request) (context.Context, context. return ctx, cancel, nil } -func (g *grpcHandler) CanHandlePayload(request *http.Request, contentType string) bool { +func (g *grpcHandler) CanHandlePayload(_ *http.Request, contentType string) bool { _, ok := g.accept[contentType] return ok } @@ -422,6 +422,8 @@ func (cc *grpcClientConn) validateResponse(response *http.Response) *Error { cc.responseHeader, cc.responseTrailer, cc.compressionPools, + cc.unmarshaler.web, + cc.marshaler.codec.Name(), cc.protobuf, ); err != nil { return err @@ -644,11 +646,16 @@ func grpcValidateResponse( response *http.Response, header, trailer http.Header, availableCompressors readOnlyCompressionPools, + web bool, + codecName string, protobuf Codec, ) *Error { if response.StatusCode != http.StatusOK { return errorf(grpcHTTPToCode(response.StatusCode), "HTTP status %v", response.Status) } + if err := grpcValidateResponseContentType(web, codecName, getHeaderCanonical(response.Header, headerContentType)); err != nil { + return err + } if compression := getHeaderCanonical(response.Header, grpcHeaderCompression); compression != "" && compression != compressionIdentity && !availableCompressors.Contains(compression) { @@ -998,3 +1005,25 @@ func validateHex(input string) error { } return nil } + +func grpcValidateResponseContentType(web bool, requestCodecName string, responseContentType string) *Error { + // Responses must have valid content-type that indicates same codec as the request. + bare, prefix := grpcContentTypeDefault, grpcContentTypePrefix + if web { + bare, prefix = grpcWebContentTypeDefault, grpcWebContentTypePrefix + } + if responseContentType == prefix+requestCodecName || + (requestCodecName == codecNameProto && responseContentType == bare) { + return nil + } + expectedContentType := bare + if requestCodecName != codecNameProto { + expectedContentType = prefix + requestCodecName + } + return errorf( + CodeInternal, + "invalid content-type: %q; expecting %q", + responseContentType, + expectedContentType, + ) +} diff --git a/protocol_grpc_test.go b/protocol_grpc_test.go index 5bef77fe..68c0feba 100644 --- a/protocol_grpc_test.go +++ b/protocol_grpc_test.go @@ -16,6 +16,7 @@ package connect import ( "errors" + "fmt" "math" "net/http" "net/http/httptest" @@ -230,3 +231,125 @@ func BenchmarkGRPCTimeoutEncoding(b *testing.B) { } } } + +func TestGRPCValidateResponseContentType(t *testing.T) { + t.Parallel() + testCases := []struct { + web bool + codecName string + responseContentType string + expectErr bool + }{ + // Allowed content-types + { + codecName: codecNameProto, + responseContentType: "application/grpc", + }, + { + codecName: codecNameProto, + responseContentType: "application/grpc+proto", + }, + { + codecName: codecNameJSON, + responseContentType: "application/grpc+json", + }, + { + codecName: codecNameProto, + web: true, + responseContentType: "application/grpc-web", + }, + { + codecName: codecNameProto, + web: true, + responseContentType: "application/grpc-web+proto", + }, + { + codecName: codecNameJSON, + web: true, + responseContentType: "application/grpc-web+json", + }, + // Disallowed content-types + { + codecName: codecNameProto, + responseContentType: "application/proto", + expectErr: true, + }, + { + codecName: codecNameProto, + responseContentType: "application/grpc-web", + expectErr: true, + }, + { + codecName: codecNameProto, + responseContentType: "application/grpc-web+proto", + expectErr: true, + }, + { + codecName: codecNameJSON, + responseContentType: "application/json", + expectErr: true, + }, + { + codecName: codecNameJSON, + responseContentType: "application/grpc-web+json", + expectErr: true, + }, + { + codecName: codecNameProto, + web: true, + responseContentType: "application/proto", + expectErr: true, + }, + { + codecName: codecNameProto, + web: true, + responseContentType: "application/grpc", + expectErr: true, + }, + { + codecName: codecNameProto, + web: true, + responseContentType: "application/grpc+proto", + expectErr: true, + }, + { + codecName: codecNameJSON, + web: true, + responseContentType: "application/json", + expectErr: true, + }, + { + codecName: codecNameJSON, + web: true, + responseContentType: "application/grpc+json", + expectErr: true, + }, + { + codecName: codecNameProto, + responseContentType: "some/garbage", + expectErr: true, + }, + } + for _, testCase := range testCases { + testCase := testCase + protocol := ProtocolGRPC + if testCase.web { + protocol = ProtocolGRPCWeb + } + testCaseName := fmt.Sprintf("%s_%s->%s", protocol, testCase.codecName, testCase.responseContentType) + t.Run(testCaseName, func(t *testing.T) { + t.Parallel() + err := grpcValidateResponseContentType( + testCase.web, + testCase.codecName, + testCase.responseContentType, + ) + if !testCase.expectErr { + assert.Nil(t, err) + } else if assert.NotNil(t, err) { + assert.Equal(t, CodeOf(err), CodeInternal) + assert.True(t, strings.Contains(err.Message(), fmt.Sprintf("invalid content-type: %q; expecting", testCase.responseContentType))) + } + }) + } +} From 4d4f97affb3b63062abdcc83ac9b9035127b0d18 Mon Sep 17 00:00:00 2001 From: Josh Humphries <2035234+jhump@users.noreply.github.com> Date: Wed, 7 Feb 2024 15:34:18 -0500 Subject: [PATCH 2/2] make gocritic happy --- protocol_connect_test.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/protocol_connect_test.go b/protocol_connect_test.go index cc53bcfd..6b524870 100644 --- a/protocol_connect_test.go +++ b/protocol_connect_test.go @@ -239,11 +239,12 @@ func TestConnectValidateUnaryResponseContentType(t *testing.T) { assert.Nil(t, err) } else if assert.NotNil(t, err) { assert.Equal(t, CodeOf(err), testCase.expectCode) - if testCase.expectNotModified { + switch { + case testCase.expectNotModified: assert.ErrorIs(t, err, errNotModified) - } else if testCase.expectBadContentType { + case testCase.expectBadContentType: assert.True(t, strings.Contains(err.Message(), fmt.Sprintf("invalid content-type: %q; expecting", testCase.responseContentType))) - } else { + default: assert.Equal(t, err.Message(), http.StatusText(testCase.statusCode)) } }