diff --git a/.golangci.yml b/.golangci.yml index 48828844..fc627292 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -34,8 +34,8 @@ linters: enable-all: true disable: - cyclop # covered by gocyclo - - depguard # unnecessary for small libraries - deadcode # abandoned + - depguard # unnecessary for small libraries - exhaustivestruct # replaced by exhaustruct - funlen # rely on code review to limit function length - gocognit # dubious "cognitive overhead" quantification diff --git a/codec_test.go b/codec_test.go index 7859be63..ce8fb92c 100644 --- a/codec_test.go +++ b/codec_test.go @@ -128,7 +128,7 @@ func TestStableCodec(t *testing.T) { func TestJSONCodec(t *testing.T) { t.Parallel() - codec := &protoJSONCodec{name: "json"} + codec := &protoJSONCodec{name: codecNameJSON} t.Run("success", func(t *testing.T) { t.Parallel() diff --git a/error_writer.go b/error_writer.go index a7878031..629918ea 100644 --- a/error_writer.go +++ b/error_writer.go @@ -45,6 +45,7 @@ type ErrorWriter struct { grpcWebContentTypes map[string]struct{} unaryConnectContentTypes map[string]struct{} streamingConnectContentTypes map[string]struct{} + requireConnectProtocolHeader bool } // NewErrorWriter constructs an ErrorWriter. To properly recognize supported @@ -60,6 +61,7 @@ func NewErrorWriter(opts ...HandlerOption) *ErrorWriter { grpcWebContentTypes: make(map[string]struct{}), unaryConnectContentTypes: make(map[string]struct{}), streamingConnectContentTypes: make(map[string]struct{}), + requireConnectProtocolHeader: config.RequireConnectProtocolHeader, } for name := range config.Codecs { unary := connectContentTypeFromCodecName(StreamTypeUnary, name) @@ -87,9 +89,17 @@ func NewErrorWriter(opts ...HandlerOption) *ErrorWriter { func (w *ErrorWriter) classifyRequest(request *http.Request) protocolType { ctype := canonicalizeContentType(getHeaderCanonical(request.Header, headerContentType)) if _, ok := w.unaryConnectContentTypes[ctype]; ok { + if err := connectCheckProtocolVersion(request, w.requireConnectProtocolHeader); err != nil { + return unknownProtocol + } return connectUnaryProtocol } if _, ok := w.streamingConnectContentTypes[ctype]; ok { + // Streaming ignores the requireConnectProtocolHeader option as the + // Content-Type is enough to determine the protocol. + if err := connectCheckProtocolVersion(request, false /* required */); err != nil { + return unknownProtocol + } return connectStreamProtocol } if _, ok := w.grpcContentTypes[ctype]; ok { diff --git a/error_writer_test.go b/error_writer_test.go new file mode 100644 index 00000000..0b3be022 --- /dev/null +++ b/error_writer_test.go @@ -0,0 +1,55 @@ +// Copyright 2021-2024 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connect + +import ( + "net/http" + "net/http/httptest" + "testing" + + "connectrpc.com/connect/internal/assert" +) + +func TestErrorWriter(t *testing.T) { + t.Parallel() + + t.Run("RequireConnectProtocolHeader", func(t *testing.T) { + t.Parallel() + writer := NewErrorWriter(WithRequireConnectProtocolHeader()) + + t.Run("Unary", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "http://localhost", nil) + req.Header.Set("Content-Type", connectUnaryContentTypePrefix+codecNameJSON) + assert.False(t, writer.IsSupported(req)) + req.Header.Set(connectHeaderProtocolVersion, connectProtocolVersion) + assert.True(t, writer.IsSupported(req)) + }) + t.Run("UnaryGET", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + assert.False(t, writer.IsSupported(req)) + query := req.URL.Query() + query.Set(connectUnaryConnectQueryParameter, connectUnaryConnectQueryValue) + req.URL.RawQuery = query.Encode() + assert.True(t, writer.IsSupported(req)) + }) + t.Run("Stream", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "http://localhost", nil) + req.Header.Set("Content-Type", connectStreamingContentTypePrefix+codecNameJSON) + assert.True(t, writer.IsSupported(req)) // ignores WithRequireConnectProtocolHeader + req.Header.Set(connectHeaderProtocolVersion, connectProtocolVersion) + assert.True(t, writer.IsSupported(req)) + }) + }) +} diff --git a/option.go b/option.go index b41a24a2..9cc0c2ee 100644 --- a/option.go +++ b/option.go @@ -159,7 +159,7 @@ func WithRecover(handle func(context.Context, Spec, http.Header, any) error) Han // header. This ensures that HTTP proxies and net/http middleware can easily // identify valid Connect requests, even if they use a common Content-Type like // application/json. However, it makes ad-hoc requests with tools like cURL -// more laborious. +// more laborious. Streaming requests are not affected by this option. // // This option has no effect if the client uses the gRPC or gRPC-Web protocols. func WithRequireConnectProtocolHeader() HandlerOption { diff --git a/protocol_connect.go b/protocol_connect.go index 348941ed..e0736442 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -48,7 +48,7 @@ const ( connectFlagEnvelopeEndStream = 0b00000010 connectUnaryContentTypePrefix = "application/" - connectUnaryContentTypeJSON = connectUnaryContentTypePrefix + "json" + connectUnaryContentTypeJSON = connectUnaryContentTypePrefix + codecNameJSON connectStreamingContentTypePrefix = "application/connect+" connectUnaryEncodingQueryParameter = "encoding" @@ -172,21 +172,9 @@ func (h *connectHandler) NewConn( if failed == nil { failed = checkServerStreamsCanFlush(h.Spec, responseWriter) } - if failed == nil && request.Method == http.MethodGet { - version := query.Get(connectUnaryConnectQueryParameter) - if version == "" && h.RequireConnectProtocolHeader { - failed = errorf(CodeInvalidArgument, "missing required query parameter: set %s to %q", connectUnaryConnectQueryParameter, connectUnaryConnectQueryValue) - } else if version != "" && version != connectUnaryConnectQueryValue { - failed = errorf(CodeInvalidArgument, "%s must be %q: got %q", connectUnaryConnectQueryParameter, connectUnaryConnectQueryValue, version) - } - } - if failed == nil && request.Method == http.MethodPost { - version := getHeaderCanonical(request.Header, connectHeaderProtocolVersion) - if version == "" && h.RequireConnectProtocolHeader { - failed = errorf(CodeInvalidArgument, "missing required header: set %s to %q", connectHeaderProtocolVersion, connectProtocolVersion) - } else if version != "" && version != connectProtocolVersion { - failed = errorf(CodeInvalidArgument, "%s must be %q: got %q", connectHeaderProtocolVersion, connectProtocolVersion, version) - } + if failed == nil { + required := h.RequireConnectProtocolHeader && (h.Spec.StreamType == StreamTypeUnary) + failed = connectCheckProtocolVersion(request, required) } var requestBody io.ReadCloser @@ -1442,3 +1430,25 @@ func connectValidateStreamResponseContentType(requestCodecName string, streamTyp } return nil } + +func connectCheckProtocolVersion(request *http.Request, required bool) *Error { + switch request.Method { + case http.MethodGet: + version := request.URL.Query().Get(connectUnaryConnectQueryParameter) + if version == "" && required { + return errorf(CodeInvalidArgument, "missing required query parameter: set %s to %q", connectUnaryConnectQueryParameter, connectUnaryConnectQueryValue) + } else if version != "" && version != connectUnaryConnectQueryValue { + return errorf(CodeInvalidArgument, "%s must be %q: got %q", connectUnaryConnectQueryParameter, connectUnaryConnectQueryValue, version) + } + case http.MethodPost: + version := getHeaderCanonical(request.Header, connectHeaderProtocolVersion) + if version == "" && required { + return errorf(CodeInvalidArgument, "missing required header: set %s to %q", connectHeaderProtocolVersion, connectProtocolVersion) + } else if version != "" && version != connectProtocolVersion { + return errorf(CodeInvalidArgument, "%s must be %q: got %q", connectHeaderProtocolVersion, connectProtocolVersion, version) + } + default: + return errorf(CodeInvalidArgument, "unsupported method: %q", request.Method) + } + return nil +}