diff --git a/.circleci/config.yml b/.circleci/config.yml deleted file mode 100644 index 5619dfa..0000000 --- a/.circleci/config.yml +++ /dev/null @@ -1,70 +0,0 @@ -version: 2.1 - -jobs: - "test": - parameters: - version: - type: string - default: "latest" - golint: - type: boolean - default: true - modules: - type: boolean - default: true - goproxy: - type: string - default: "" - docker: - - image: "circleci/golang:<< parameters.version >>" - working_directory: /go/src/github.com/gorilla/handlers - environment: - GO111MODULE: "on" - GOPROXY: "<< parameters.goproxy >>" - steps: - - checkout - - run: - name: "Print the Go version" - command: > - go version - - run: - name: "Fetch dependencies" - command: > - if [[ << parameters.modules >> = true ]]; then - go mod download - export GO111MODULE=on - else - go get -v ./... - fi - # Only run gofmt, vet & lint against the latest Go version - - run: - name: "Run golint" - command: > - if [ << parameters.version >> = "latest" ] && [ << parameters.golint >> = true ]; then - go get -u golang.org/x/lint/golint - golint ./... - fi - - run: - name: "Run gofmt" - command: > - if [[ << parameters.version >> = "latest" ]]; then - diff -u <(echo -n) <(gofmt -d -e .) - fi - - run: - name: "Run go vet" - command: > - if [[ << parameters.version >> = "latest" ]]; then - go vet -v ./... - fi - - run: - name: "Run go test (+ race detector)" - command: > - go test -v -race ./... - -workflows: - tests: - jobs: - - test: - matrix: - parameters: - version: ["latest", "1.15", "1.14", "1.13", "1.12", "1.11"] diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..c6b74c3 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,20 @@ +; https://editorconfig.org/ + +root = true + +[*] +insert_final_newline = true +charset = utf-8 +trim_trailing_whitespace = true +indent_style = space +indent_size = 2 + +[{Makefile,go.mod,go.sum,*.go,.gitmodules}] +indent_style = tab +indent_size = 4 + +[*.md] +indent_size = 4 +trim_trailing_whitespace = false + +eclint_indent_style = unset \ No newline at end of file diff --git a/.github/release-drafter.yml b/.github/release-drafter.yml deleted file mode 100644 index 2db2e13..0000000 --- a/.github/release-drafter.yml +++ /dev/null @@ -1,8 +0,0 @@ -# Config for https://github.com/apps/release-drafter -template: | - - - - ## CHANGELOG - - $CHANGES diff --git a/.github/stale.yml b/.github/stale.yml deleted file mode 100644 index de8a678..0000000 --- a/.github/stale.yml +++ /dev/null @@ -1,12 +0,0 @@ -daysUntilStale: 60 -daysUntilClose: 7 -# Issues with these labels will never be considered stale -exemptLabels: - - v2 - - needs-review - - work-required -staleLabel: stale -markComment: > - This issue has been automatically marked as stale because it hasn't seen - a recent update. It'll be automatically closed in a few days. -closeComment: false diff --git a/.github/workflows/issues.yml b/.github/workflows/issues.yml new file mode 100644 index 0000000..055ca82 --- /dev/null +++ b/.github/workflows/issues.yml @@ -0,0 +1,20 @@ +# Add issues or pull-requests created to the project. +name: Add issue or pull request to Project + +on: + issues: + types: + - opened + pull_request: + types: + - opened + +jobs: + add-to-project: + runs-on: ubuntu-latest + steps: + - name: Add issue to project + uses: actions/add-to-project@v0.5.0 + with: + project-url: https://github.com/orgs/gorilla/projects/4 + github-token: ${{ secrets.ADD_TO_PROJECT_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..a93d214 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,55 @@ +name: CI +on: + push: + branches: + - main + pull_request: + branches: + - main + +permissions: + contents: read + +jobs: + verify-and-test: + strategy: + matrix: + go: ['1.19','1.20'] + os: [ubuntu-latest, macos-latest, windows-latest] + fail-fast: true + runs-on: ${{ matrix.os }} + steps: + - name: Checkout Code + uses: actions/checkout@v3 + + - name: Setup Go ${{ matrix.go }} + uses: actions/setup-go@v4 + with: + go-version: ${{ matrix.go }} + cache: false + + - name: Run golangci-lint + uses: golangci/golangci-lint-action@v3 + with: + version: v1.53 + args: --timeout=5m + + - name: Run gosec + if: matrix.os == 'ubuntu-latest' + uses: securego/gosec@master + with: + args: ./... + + - name: Run govulncheck + uses: golang/govulncheck-action@v1 + with: + go-version-input: ${{ matrix.go }} + go-package: ./... + + - name: Run tests + run: go test -race -cover -coverprofile=coverage -covermode=atomic -v ./... + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + files: ./coverage \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..577a89e --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +# Output of the go test coverage tool +coverage.coverprofile diff --git a/LICENSE b/LICENSE index 66ea3c8..bb9d80b 100644 --- a/LICENSE +++ b/LICENSE @@ -1,22 +1,27 @@ -Copyright (c) 2013 The Gorilla Handlers Authors. All rights reserved. +Copyright (c) 2023 The Gorilla Authors. All rights reserved. Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: +modification, are permitted provided that the following conditions are +met: - Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. - Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..003b784 --- /dev/null +++ b/Makefile @@ -0,0 +1,34 @@ +GO_LINT=$(shell which golangci-lint 2> /dev/null || echo '') +GO_LINT_URI=github.com/golangci/golangci-lint/cmd/golangci-lint@latest + +GO_SEC=$(shell which gosec 2> /dev/null || echo '') +GO_SEC_URI=github.com/securego/gosec/v2/cmd/gosec@latest + +GO_VULNCHECK=$(shell which govulncheck 2> /dev/null || echo '') +GO_VULNCHECK_URI=golang.org/x/vuln/cmd/govulncheck@latest + +.PHONY: verify +verify: sec govulncheck lint test + +.PHONY: lint +lint: + $(if $(GO_LINT), ,go install $(GO_LINT_URI)) + @echo "##### Running golangci-lint #####" + golangci-lint run -v + +.PHONY: sec +sec: + $(if $(GO_SEC), ,go install $(GO_SEC_URI)) + @echo "##### Running gosec #####" + gosec ./... + +.PHONY: govulncheck +govulncheck: + $(if $(GO_VULNCHECK), ,go install $(GO_VULNCHECK_URI)) + @echo "##### Running govulncheck #####" + govulncheck ./... + +.PHONY: test +test: + @echo "##### Running tests #####" + go test -race -cover -coverprofile=coverage.coverprofile -covermode=atomic -v ./... diff --git a/README.md b/README.md index d9b9d46..02555b2 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,8 @@ -gorilla/handlers -================ +# gorilla/handlers + +![Testing](https://github.com/gorilla/handlers/actions/workflows/test.yml/badge.svg) +[![Codecov](https://codecov.io/github/gorilla/handlers/branch/main/graph/badge.svg)](https://codecov.io/github/gorilla/handlers) [![GoDoc](https://godoc.org/github.com/gorilla/handlers?status.svg)](https://godoc.org/github.com/gorilla/handlers) -[![CircleCI](https://circleci.com/gh/gorilla/handlers.svg?style=svg)](https://circleci.com/gh/gorilla/handlers) [![Sourcegraph](https://sourcegraph.com/github.com/gorilla/handlers/-/badge.svg)](https://sourcegraph.com/github.com/gorilla/handlers?badge) Package handlers is a collection of handlers (aka "HTTP middleware") for use diff --git a/canonical.go b/canonical.go index 8437fef..7121f53 100644 --- a/canonical.go +++ b/canonical.go @@ -21,12 +21,11 @@ type canonical struct { // // Example: // -// r := mux.NewRouter() -// canonical := handlers.CanonicalHost("http://www.gorillatoolkit.org", 302) -// r.HandleFunc("/route", YourHandler) -// -// log.Fatal(http.ListenAndServe(":7000", canonical(r))) +// r := mux.NewRouter() +// canonical := handlers.CanonicalHost("http://www.gorillatoolkit.org", 302) +// r.HandleFunc("/route", YourHandler) // +// log.Fatal(http.ListenAndServe(":7000", canonical(r))) func CanonicalHost(domain string, code int) func(h http.Handler) http.Handler { fn := func(h http.Handler) http.Handler { return canonical{h, domain, code} diff --git a/canonical_test.go b/canonical_test.go index 615e4b0..2b11ccf 100644 --- a/canonical_test.go +++ b/canonical_test.go @@ -32,7 +32,7 @@ func TestCanonicalHost(t *testing.T) { gorilla := "http://www.gorillatoolkit.org" rr := httptest.NewRecorder() - r := newRequest("GET", "http://www.example.com/") + r := newRequest(http.MethodGet, "http://www.example.com/") testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) @@ -46,7 +46,6 @@ func TestCanonicalHost(t *testing.T) { if rr.Header().Get("Location") != gorilla+r.URL.Path { t.Fatalf("bad re-direct: got %q want %q", rr.Header().Get("Location"), gorilla+r.URL.Path) } - } func TestKeepsQueryString(t *testing.T) { @@ -54,7 +53,7 @@ func TestKeepsQueryString(t *testing.T) { rr := httptest.NewRecorder() querystring := url.Values{"q": {"golang"}, "format": {"json"}}.Encode() - r := newRequest("GET", "http://www.example.com/search?"+querystring) + r := newRequest(http.MethodGet, "http://www.example.com/search?"+querystring) testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CanonicalHost(google, http.StatusFound)(testHandler).ServeHTTP(rr, r) @@ -67,7 +66,7 @@ func TestKeepsQueryString(t *testing.T) { func TestBadDomain(t *testing.T) { rr := httptest.NewRecorder() - r := newRequest("GET", "http://www.example.com/") + r := newRequest(http.MethodGet, "http://www.example.com/") testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) @@ -81,7 +80,7 @@ func TestBadDomain(t *testing.T) { func TestEmptyHost(t *testing.T) { rr := httptest.NewRecorder() - r := newRequest("GET", "http://www.example.com/") + r := newRequest(http.MethodGet, "http://www.example.com/") testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) @@ -97,7 +96,7 @@ func TestHeaderWrites(t *testing.T) { gorilla := "http://www.gorillatoolkit.org" testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) }) // Catch the log output to ensure we don't write multiple headers. diff --git a/compress.go b/compress.go index 1e95f1c..d6f5895 100644 --- a/compress.go +++ b/compress.go @@ -44,13 +44,13 @@ type flusher interface { Flush() error } -func (w *compressResponseWriter) Flush() { +func (cw *compressResponseWriter) Flush() { // Flush compressed data if compressor supports it. - if f, ok := w.compressor.(flusher); ok { - f.Flush() + if f, ok := cw.compressor.(flusher); ok { + _ = f.Flush() } // Flush HTTP response. - if f, ok := w.w.(http.Flusher); ok { + if f, ok := cw.w.(http.Flusher); ok { f.Flush() } } diff --git a/compress_test.go b/compress_test.go index b9457cd..25fdbe9 100644 --- a/compress_test.go +++ b/compress_test.go @@ -9,7 +9,7 @@ import ( "bytes" "compress/gzip" "io" - "io/ioutil" + "log" "net" "net/http" "net/http/httptest" @@ -27,10 +27,13 @@ func compressedRequest(w *httptest.ResponseRecorder, compression string) { w.Header().Set("Content-Length", strconv.Itoa(9*1024)) w.Header().Set("Content-Type", contentType) for i := 0; i < 1024; i++ { - io.WriteString(w, "Gorilla!\n") + _, err := io.WriteString(w, "Gorilla!\n") + if err != nil { + log.Printf("error on writting to http.ResponseWriter: %v\n", err) + } } })).ServeHTTP(w, &http.Request{ - Method: "GET", + Method: http.MethodGet, Header: http.Header{ acceptEncoding: []string{compression}, }, @@ -40,19 +43,20 @@ func compressedRequest(w *httptest.ResponseRecorder, compression string) { func TestCompressHandlerNoCompression(t *testing.T) { w := httptest.NewRecorder() compressedRequest(w, "") - if enc := w.HeaderMap.Get("Content-Encoding"); enc != "" { + resp := w.Result() + if enc := resp.Header.Get("Content-Encoding"); enc != "" { t.Errorf("wrong content encoding, got %q want %q", enc, "") } - if ct := w.HeaderMap.Get("Content-Type"); ct != contentType { + if ct := resp.Header.Get("Content-Type"); ct != contentType { t.Errorf("wrong content type, got %q want %q", ct, contentType) } if w.Body.Len() != 1024*9 { t.Errorf("wrong len, got %d want %d", w.Body.Len(), 1024*9) } - if l := w.HeaderMap.Get("Content-Length"); l != "9216" { + if l := resp.Header.Get("Content-Length"); l != "9216" { t.Errorf("wrong content-length. got %q expected %d", l, 1024*9) } - if v := w.HeaderMap.Get("Vary"); v != acceptEncoding { + if v := resp.Header.Get("Vary"); v != acceptEncoding { t.Errorf("wrong vary. got %s expected %s", v, acceptEncoding) } } @@ -114,7 +118,7 @@ func TestAcceptEncodingIsDropped(t *testing.T) { w := httptest.NewRecorder() ch.ServeHTTP(w, &http.Request{ - Method: "GET", + Method: http.MethodGet, Header: http.Header{ acceptEncoding: []string{tCase.compression}, }, @@ -125,16 +129,17 @@ func TestAcceptEncodingIsDropped(t *testing.T) { func TestCompressHandlerGzip(t *testing.T) { w := httptest.NewRecorder() compressedRequest(w, "gzip") - if w.HeaderMap.Get("Content-Encoding") != "gzip" { - t.Errorf("wrong content encoding, got %q want %q", w.HeaderMap.Get("Content-Encoding"), "gzip") + resp := w.Result() + if resp.Header.Get("Content-Encoding") != "gzip" { + t.Errorf("wrong content encoding, got %q want %q", resp.Header.Get("Content-Encoding"), "gzip") } - if w.HeaderMap.Get("Content-Type") != "text/plain; charset=utf-8" { - t.Errorf("wrong content type, got %s want %s", w.HeaderMap.Get("Content-Type"), "text/plain; charset=utf-8") + if resp.Header.Get("Content-Type") != "text/plain; charset=utf-8" { + t.Errorf("wrong content type, got %s want %s", resp.Header.Get("Content-Type"), "text/plain; charset=utf-8") } if w.Body.Len() != 72 { t.Errorf("wrong len, got %d want %d", w.Body.Len(), 72) } - if l := w.HeaderMap.Get("Content-Length"); l != "" { + if l := resp.Header.Get("Content-Length"); l != "" { t.Errorf("wrong content-length. got %q expected %q", l, "") } } @@ -142,11 +147,12 @@ func TestCompressHandlerGzip(t *testing.T) { func TestCompressHandlerDeflate(t *testing.T) { w := httptest.NewRecorder() compressedRequest(w, "deflate") - if w.HeaderMap.Get("Content-Encoding") != "deflate" { - t.Fatalf("wrong content encoding, got %q want %q", w.HeaderMap.Get("Content-Encoding"), "deflate") + resp := w.Result() + if resp.Header.Get("Content-Encoding") != "deflate" { + t.Fatalf("wrong content encoding, got %q want %q", resp.Header.Get("Content-Encoding"), "deflate") } - if w.HeaderMap.Get("Content-Type") != "text/plain; charset=utf-8" { - t.Fatalf("wrong content type, got %s want %s", w.HeaderMap.Get("Content-Type"), "text/plain; charset=utf-8") + if resp.Header.Get("Content-Type") != "text/plain; charset=utf-8" { + t.Fatalf("wrong content type, got %s want %s", resp.Header.Get("Content-Type"), "text/plain; charset=utf-8") } if w.Body.Len() != 54 { t.Fatalf("wrong len, got %d want %d", w.Body.Len(), 54) @@ -156,11 +162,12 @@ func TestCompressHandlerDeflate(t *testing.T) { func TestCompressHandlerGzipDeflate(t *testing.T) { w := httptest.NewRecorder() compressedRequest(w, "gzip, deflate ") - if w.HeaderMap.Get("Content-Encoding") != "gzip" { - t.Fatalf("wrong content encoding, got %q want %q", w.HeaderMap.Get("Content-Encoding"), "gzip") + resp := w.Result() + if resp.Header.Get("Content-Encoding") != "gzip" { + t.Fatalf("wrong content encoding, got %q want %q", resp.Header.Get("Content-Encoding"), "gzip") } - if w.HeaderMap.Get("Content-Type") != "text/plain; charset=utf-8" { - t.Fatalf("wrong content type, got %s want %s", w.HeaderMap.Get("Content-Type"), "text/plain; charset=utf-8") + if resp.Header.Get("Content-Type") != "text/plain; charset=utf-8" { + t.Fatalf("wrong content type, got %s want %s", resp.Header.Get("Content-Type"), "text/plain; charset=utf-8") } } @@ -168,13 +175,13 @@ func TestCompressHandlerGzipDeflate(t *testing.T) { // to use a real http server to trigger the net/http sendfile special // case. func TestCompressFile(t *testing.T) { - dir, err := ioutil.TempDir("", "gorilla_compress") + dir, err := os.MkdirTemp("", "gorilla_compress") if err != nil { t.Fatal(err) } defer os.RemoveAll(dir) - err = ioutil.WriteFile(filepath.Join(dir, "hello.txt"), []byte("hello"), 0644) + err = os.WriteFile(filepath.Join(dir, "hello.txt"), []byte("hello"), 0o644) if err != nil { t.Fatal(err) } @@ -183,7 +190,7 @@ func TestCompressFile(t *testing.T) { defer s.Close() url := &url.URL{Scheme: "http", Host: s.Listener.Addr().String(), Path: "/hello.txt"} - req, err := http.NewRequest("GET", url.String(), nil) + req, err := http.NewRequest(http.MethodGet, url.String(), nil) if err != nil { t.Fatal(err) } @@ -232,34 +239,25 @@ func (fullyFeaturedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) return nil, nil, nil } -// CloseNotify implements the http.CloseNotifier interface. -func (fullyFeaturedResponseWriter) CloseNotify() <-chan bool { - return nil -} - func TestCompressHandlerPreserveInterfaces(t *testing.T) { // Compile time validation fullyFeaturedResponseWriter implements all the // interfaces we're asserting in the test case below. var ( - _ http.Flusher = fullyFeaturedResponseWriter{} - _ http.CloseNotifier = fullyFeaturedResponseWriter{} - _ http.Hijacker = fullyFeaturedResponseWriter{} + _ http.Flusher = fullyFeaturedResponseWriter{} + _ http.Hijacker = fullyFeaturedResponseWriter{} ) var h http.Handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { comp := r.Header.Get(acceptEncoding) if _, ok := rw.(http.Flusher); !ok { t.Errorf("ResponseWriter lost http.Flusher interface for %q", comp) } - if _, ok := rw.(http.CloseNotifier); !ok { - t.Errorf("ResponseWriter lost http.CloseNotifier interface for %q", comp) - } if _, ok := rw.(http.Hijacker); !ok { t.Errorf("ResponseWriter lost http.Hijacker interface for %q", comp) } }) h = CompressHandler(h) var rw fullyFeaturedResponseWriter - r, err := http.NewRequest("GET", "/", nil) + r, err := http.NewRequest(http.MethodGet, "/", nil) if err != nil { t.Fatalf("Failed to create test request: %v", err) } @@ -291,7 +289,7 @@ func TestCompressHandlerDoesntInventInterfaces(t *testing.T) { h = CompressHandler(h) var rw paltryResponseWriter - r, err := http.NewRequest("GET", "/", nil) + r, err := http.NewRequest(http.MethodGet, "/", nil) if err != nil { t.Fatalf("Failed to create test request: %v", err) } diff --git a/cors.go b/cors.go index 0dcdffb..42892bd 100644 --- a/cors.go +++ b/cors.go @@ -27,9 +27,9 @@ type OriginValidator func(string) bool var ( defaultCorsOptionStatusCode = 200 - defaultCorsMethods = []string{"GET", "HEAD", "POST"} + defaultCorsMethods = []string{http.MethodGet, "HEAD", http.MethodPost} defaultCorsHeaders = []string{"Accept", "Accept-Language", "Content-Language", "Origin"} - // (WebKit/Safari v9 sends the Origin header by default in AJAX requests) + // (WebKit/Safari v9 sends the Origin header by default in AJAX requests). ) const ( @@ -101,10 +101,8 @@ func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) { if !ch.isMatch(method, defaultCorsMethods) { w.Header().Set(corsAllowMethodsHeader, method) } - } else { - if len(ch.exposedHeaders) > 0 { - w.Header().Set(corsExposeHeadersHeader, strings.Join(ch.exposedHeaders, ",")) - } + } else if len(ch.exposedHeaders) > 0 { + w.Header().Set(corsExposeHeadersHeader, strings.Join(ch.exposedHeaders, ",")) } if ch.allowCredentials { @@ -141,22 +139,21 @@ func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) { // CORS provides Cross-Origin Resource Sharing middleware. // Example: // -// import ( -// "net/http" -// -// "github.com/gorilla/handlers" -// "github.com/gorilla/mux" -// ) +// import ( +// "net/http" // -// func main() { -// r := mux.NewRouter() -// r.HandleFunc("/users", UserEndpoint) -// r.HandleFunc("/projects", ProjectEndpoint) +// "github.com/gorilla/handlers" +// "github.com/gorilla/mux" +// ) // -// // Apply the CORS middleware to our top-level router, with the defaults. -// http.ListenAndServe(":8000", handlers.CORS()(r)) -// } +// func main() { +// r := mux.NewRouter() +// r.HandleFunc("/users", UserEndpoint) +// r.HandleFunc("/projects", ProjectEndpoint) // +// // Apply the CORS middleware to our top-level router, with the defaults. +// http.ListenAndServe(":8000", handlers.CORS()(r)) +// } func CORS(opts ...CORSOption) func(http.Handler) http.Handler { return func(h http.Handler) http.Handler { ch := parseCORSOptions(opts...) @@ -174,7 +171,7 @@ func parseCORSOptions(opts ...CORSOption) *cors { } for _, option := range opts { - option(ch) + _ = option(ch) //TODO: @bharat-rajani, return error to caller if not nil? } return ch diff --git a/cors_test.go b/cors_test.go index ee7d184..777a420 100644 --- a/cors_test.go +++ b/cors_test.go @@ -8,20 +8,20 @@ import ( ) func TestDefaultCORSHandlerReturnsOk(t *testing.T) { - r := newRequest("GET", "http://www.example.com/") + r := newRequest(http.MethodGet, "http://www.example.com/") rr := httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CORS()(testHandler).ServeHTTP(rr, r) - - if got, want := rr.Code, http.StatusOK; got != want { + resp := rr.Result() + if got, want := resp.StatusCode, http.StatusOK; got != want { t.Fatalf("bad status: got %v want %v", got, want) } } func TestDefaultCORSHandlerReturnsOkWithOrigin(t *testing.T) { - r := newRequest("GET", "http://www.example.com/") + r := newRequest(http.MethodGet, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) rr := httptest.NewRecorder() @@ -29,14 +29,14 @@ func TestDefaultCORSHandlerReturnsOkWithOrigin(t *testing.T) { testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CORS()(testHandler).ServeHTTP(rr, r) - - if got, want := rr.Code, http.StatusOK; got != want { + resp := rr.Result() + if got, want := resp.StatusCode, http.StatusOK; got != want { t.Fatalf("bad status: got %v want %v", got, want) } } func TestCORSHandlerIgnoreOptionsFallsThrough(t *testing.T) { - r := newRequest("OPTIONS", "http://www.example.com/") + r := newRequest(http.MethodOptions, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) rr := httptest.NewRecorder() @@ -46,68 +46,70 @@ func TestCORSHandlerIgnoreOptionsFallsThrough(t *testing.T) { }) CORS(IgnoreOptions())(testHandler).ServeHTTP(rr, r) + resp := rr.Result() - if got, want := rr.Code, http.StatusTeapot; got != want { + if got, want := resp.StatusCode, http.StatusTeapot; got != want { t.Fatalf("bad status: got %v want %v", got, want) } } func TestCORSHandlerSetsExposedHeaders(t *testing.T) { // Test default configuration. - r := newRequest("GET", "http://www.example.com/") + r := newRequest(http.MethodGet, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) rr := httptest.NewRecorder() - testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CORS(ExposedHeaders([]string{"X-CORS-TEST"}))(testHandler).ServeHTTP(rr, r) + resp := rr.Result() - if got, want := rr.Code, http.StatusOK; got != want { + if got, want := resp.StatusCode, http.StatusOK; got != want { t.Fatalf("bad status: got %v want %v", got, want) } - header := rr.HeaderMap.Get(corsExposeHeadersHeader) + header := resp.Header.Get(corsExposeHeadersHeader) if got, want := header, "X-Cors-Test"; got != want { t.Fatalf("bad header: expected %q header, got empty header for method.", want) } } func TestCORSHandlerUnsetRequestMethodForPreflightBadRequest(t *testing.T) { - r := newRequest("OPTIONS", "http://www.example.com/") + r := newRequest(http.MethodOptions, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) rr := httptest.NewRecorder() - testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CORS(AllowedMethods([]string{"DELETE"}))(testHandler).ServeHTTP(rr, r) + resp := rr.Result() - if got, want := rr.Code, http.StatusBadRequest; got != want { + if got, want := resp.StatusCode, http.StatusBadRequest; got != want { t.Fatalf("bad status: got %v want %v", got, want) } } func TestCORSHandlerInvalidRequestMethodForPreflightMethodNotAllowed(t *testing.T) { - r := newRequest("OPTIONS", "http://www.example.com/") + r := newRequest(http.MethodOptions, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) - r.Header.Set(corsRequestMethodHeader, "DELETE") + r.Header.Set(corsRequestMethodHeader, http.MethodDelete) rr := httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CORS()(testHandler).ServeHTTP(rr, r) + resp := rr.Result() - if got, want := rr.Code, http.StatusMethodNotAllowed; got != want { + if got, want := resp.StatusCode, http.StatusMethodNotAllowed; got != want { t.Fatalf("bad status: got %v want %v", got, want) } } func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandler(t *testing.T) { - r := newRequest("OPTIONS", "http://www.example.com/") + r := newRequest(http.MethodOptions, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) - r.Header.Set(corsRequestMethodHeader, "GET") + r.Header.Set(corsRequestMethodHeader, http.MethodGet) rr := httptest.NewRecorder() @@ -116,17 +118,18 @@ func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandler(t *testing.T) { }) CORS()(testHandler).ServeHTTP(rr, r) + resp := rr.Result() - if got, want := rr.Code, http.StatusOK; got != want { + if got, want := resp.StatusCode, http.StatusOK; got != want { t.Fatalf("bad status: got %v want %v", got, want) } } func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWithCustomStatusCode(t *testing.T) { statusCode := http.StatusNoContent - r := newRequest("OPTIONS", "http://www.example.com/") + r := newRequest(http.MethodOptions, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) - r.Header.Set(corsRequestMethodHeader, "GET") + r.Header.Set(corsRequestMethodHeader, http.MethodGet) rr := httptest.NewRecorder() @@ -135,16 +138,17 @@ func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWithCustomStatusCo }) CORS(OptionStatusCode(statusCode))(testHandler).ServeHTTP(rr, r) + resp := rr.Result() - if got, want := rr.Code, statusCode; got != want { + if got, want := resp.StatusCode, statusCode; got != want { t.Fatalf("bad status: got %v want %v", got, want) } } func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWhenOriginNotAllowed(t *testing.T) { - r := newRequest("OPTIONS", "http://www.example.com/") + r := newRequest(http.MethodOptions, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) - r.Header.Set(corsRequestMethodHeader, "GET") + r.Header.Set(corsRequestMethodHeader, http.MethodGet) rr := httptest.NewRecorder() @@ -153,36 +157,38 @@ func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWhenOriginNotAllow }) CORS(AllowedOrigins([]string{}))(testHandler).ServeHTTP(rr, r) + resp := rr.Result() - if got, want := rr.Code, http.StatusOK; got != want { + if got, want := resp.StatusCode, http.StatusOK; got != want { t.Fatalf("bad status: got %v want %v", got, want) } } func TestCORSHandlerAllowedMethodForPreflight(t *testing.T) { - r := newRequest("OPTIONS", "http://www.example.com/") + r := newRequest(http.MethodOptions, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) - r.Header.Set(corsRequestMethodHeader, "DELETE") + r.Header.Set(corsRequestMethodHeader, http.MethodDelete) - rr := httptest.NewRecorder() + rc := httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) - CORS(AllowedMethods([]string{"DELETE"}))(testHandler).ServeHTTP(rr, r) + CORS(AllowedMethods([]string{"DELETE"}))(testHandler).ServeHTTP(rc, r) - if got, want := rr.Code, http.StatusOK; got != want { + resp := rc.Result() + if got, want := resp.StatusCode, http.StatusOK; got != want { t.Fatalf("bad status: got %v want %v", got, want) } - header := rr.HeaderMap.Get(corsAllowMethodsHeader) - if got, want := header, "DELETE"; got != want { - t.Fatalf("bad header: expected %q method header, got %q header.", want, got) + header := resp.Header.Get(corsAllowMethodsHeader) + if header != http.MethodDelete { + t.Fatalf("bad header: expected %q method header, got %q header.", http.MethodDelete, header) } } func TestCORSHandlerAllowMethodsNotSetForSimpleRequestPreflight(t *testing.T) { for _, method := range defaultCorsMethods { - r := newRequest("OPTIONS", "http://www.example.com/") + r := newRequest(http.MethodOptions, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) r.Header.Set(corsRequestMethodHeader, method) @@ -191,12 +197,13 @@ func TestCORSHandlerAllowMethodsNotSetForSimpleRequestPreflight(t *testing.T) { testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CORS()(testHandler).ServeHTTP(rr, r) + resp := rr.Result() - if got, want := rr.Code, http.StatusOK; got != want { + if got, want := resp.StatusCode, http.StatusOK; got != want { t.Fatalf("bad status: got %v want %v", got, want) } - header := rr.HeaderMap.Get(corsAllowMethodsHeader) + header := resp.Header.Get(corsAllowMethodsHeader) if got, want := header, ""; got != want { t.Fatalf("bad header: expected %q method header, got %q.", want, got) } @@ -205,9 +212,9 @@ func TestCORSHandlerAllowMethodsNotSetForSimpleRequestPreflight(t *testing.T) { func TestCORSHandlerAllowedHeaderNotSetForSimpleRequestPreflight(t *testing.T) { for _, simpleHeader := range defaultCorsHeaders { - r := newRequest("OPTIONS", "http://www.example.com/") + r := newRequest(http.MethodOptions, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) - r.Header.Set(corsRequestMethodHeader, "GET") + r.Header.Set(corsRequestMethodHeader, http.MethodGet) r.Header.Set(corsRequestHeadersHeader, simpleHeader) rr := httptest.NewRecorder() @@ -215,12 +222,13 @@ func TestCORSHandlerAllowedHeaderNotSetForSimpleRequestPreflight(t *testing.T) { testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CORS()(testHandler).ServeHTTP(rr, r) + resp := rr.Result() - if got, want := rr.Code, http.StatusOK; got != want { + if got, want := resp.StatusCode, http.StatusOK; got != want { t.Fatalf("bad status: got %v want %v", got, want) } - header := rr.HeaderMap.Get(corsAllowHeadersHeader) + header := resp.Header.Get(corsAllowHeadersHeader) if got, want := header, ""; got != want { t.Fatalf("bad header: expected %q header, got %q.", want, got) } @@ -228,9 +236,9 @@ func TestCORSHandlerAllowedHeaderNotSetForSimpleRequestPreflight(t *testing.T) { } func TestCORSHandlerAllowedHeaderForPreflight(t *testing.T) { - r := newRequest("OPTIONS", "http://www.example.com/") + r := newRequest(http.MethodOptions, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) - r.Header.Set(corsRequestMethodHeader, "POST") + r.Header.Set(corsRequestMethodHeader, http.MethodPost) r.Header.Set(corsRequestHeadersHeader, "Content-Type") rr := httptest.NewRecorder() @@ -238,21 +246,22 @@ func TestCORSHandlerAllowedHeaderForPreflight(t *testing.T) { testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CORS(AllowedHeaders([]string{"Content-Type"}))(testHandler).ServeHTTP(rr, r) + resp := rr.Result() - if got, want := rr.Code, http.StatusOK; got != want { + if got, want := resp.StatusCode, http.StatusOK; got != want { t.Fatalf("bad status: got %v want %v", got, want) } - header := rr.HeaderMap.Get(corsAllowHeadersHeader) + header := resp.Header.Get(corsAllowHeadersHeader) if got, want := header, "Content-Type"; got != want { t.Fatalf("bad header: expected %q header, got %q header.", want, got) } } func TestCORSHandlerInvalidHeaderForPreflightForbidden(t *testing.T) { - r := newRequest("OPTIONS", "http://www.example.com/") + r := newRequest(http.MethodOptions, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) - r.Header.Set(corsRequestMethodHeader, "POST") + r.Header.Set(corsRequestMethodHeader, http.MethodPost) r.Header.Set(corsRequestHeadersHeader, "Content-Type") rr := httptest.NewRecorder() @@ -260,35 +269,37 @@ func TestCORSHandlerInvalidHeaderForPreflightForbidden(t *testing.T) { testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CORS()(testHandler).ServeHTTP(rr, r) + resp := rr.Result() - if got, want := rr.Code, http.StatusForbidden; got != want { + if got, want := resp.StatusCode, http.StatusForbidden; got != want { t.Fatalf("bad status: got %v want %v", got, want) } } func TestCORSHandlerMaxAgeForPreflight(t *testing.T) { - r := newRequest("OPTIONS", "http://www.example.com/") + r := newRequest(http.MethodOptions, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) - r.Header.Set(corsRequestMethodHeader, "POST") + r.Header.Set(corsRequestMethodHeader, http.MethodPost) rr := httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CORS(MaxAge(3500))(testHandler).ServeHTTP(rr, r) + resp := rr.Result() - if got, want := rr.Code, http.StatusOK; got != want { + if got, want := resp.StatusCode, http.StatusOK; got != want { t.Fatalf("bad status: got %v want %v", got, want) } - header := rr.HeaderMap.Get(corsMaxAgeHeader) + header := resp.Header.Get(corsMaxAgeHeader) if got, want := header, "600"; got != want { t.Fatalf("bad header: expected %q to be %q, got %q.", corsMaxAgeHeader, want, got) } } func TestCORSHandlerAllowedCredentials(t *testing.T) { - r := newRequest("GET", "http://www.example.com/") + r := newRequest(http.MethodGet, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) rr := httptest.NewRecorder() @@ -296,19 +307,20 @@ func TestCORSHandlerAllowedCredentials(t *testing.T) { testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CORS(AllowCredentials())(testHandler).ServeHTTP(rr, r) + resp := rr.Result() - if status := rr.Code; status != http.StatusOK { + if status := resp.StatusCode; status != http.StatusOK { t.Fatalf("bad status: got %v want %v", status, http.StatusOK) } - header := rr.HeaderMap.Get(corsAllowCredentialsHeader) + header := resp.Header.Get(corsAllowCredentialsHeader) if got, want := header, "true"; got != want { t.Fatalf("bad header: expected %q to be %q, got %q.", corsAllowCredentialsHeader, want, got) } } func TestCORSHandlerMultipleAllowOriginsSetsVaryHeader(t *testing.T) { - r := newRequest("GET", "http://www.example.com/") + r := newRequest(http.MethodGet, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) rr := httptest.NewRecorder() @@ -316,12 +328,13 @@ func TestCORSHandlerMultipleAllowOriginsSetsVaryHeader(t *testing.T) { testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CORS(AllowedOrigins([]string{r.URL.String(), "http://google.com"}))(testHandler).ServeHTTP(rr, r) + resp := rr.Result() - if status := rr.Code; status != http.StatusOK { + if status := resp.StatusCode; status != http.StatusOK { t.Fatalf("bad status: got %v want %v", status, http.StatusOK) } - header := rr.HeaderMap.Get(corsVaryHeader) + header := resp.Header.Get(corsVaryHeader) if got, want := header, corsOriginHeader; got != want { t.Fatalf("bad header: expected %s to be %q, got %q.", corsVaryHeader, want, got) } @@ -338,7 +351,7 @@ func TestCORSWithMultipleHandlers(t *testing.T) { lastHandledBy = "testHandler2" }) - r1 := newRequest("GET", "http://www.example.com/") + r1 := newRequest(http.MethodGet, "http://www.example.com/") rr1 := httptest.NewRecorder() handler1 := corsMiddleware(testHandler1) @@ -351,59 +364,56 @@ func TestCORSWithMultipleHandlers(t *testing.T) { } func TestCORSOriginValidatorWithImplicitStar(t *testing.T) { - r := newRequest("GET", "http://a.example.com") + r := newRequest(http.MethodGet, "http://a.example.com") r.Header.Set("Origin", r.URL.String()) rr := httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) originValidator := func(origin string) bool { - if strings.HasSuffix(origin, ".example.com") { - return true - } - return false + return strings.HasSuffix(origin, ".example.com") } CORS(AllowedOriginValidator(originValidator))(testHandler).ServeHTTP(rr, r) - header := rr.HeaderMap.Get(corsAllowOriginHeader) + resp := rr.Result() + header := resp.Header.Get(corsAllowOriginHeader) if got, want := header, r.URL.String(); got != want { t.Fatalf("bad header: expected %s to be %q, got %q.", corsAllowOriginHeader, want, got) } } func TestCORSOriginValidatorWithExplicitStar(t *testing.T) { - r := newRequest("GET", "http://a.example.com") + r := newRequest(http.MethodGet, "http://a.example.com") r.Header.Set("Origin", r.URL.String()) rr := httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) originValidator := func(origin string) bool { - if strings.HasSuffix(origin, ".example.com") { - return true - } - return false + return strings.HasSuffix(origin, ".example.com") } CORS( AllowedOriginValidator(originValidator), AllowedOrigins([]string{"*"}), )(testHandler).ServeHTTP(rr, r) - header := rr.HeaderMap.Get(corsAllowOriginHeader) + resp := rr.Result() + header := resp.Header.Get(corsAllowOriginHeader) if got, want := header, "*"; got != want { t.Fatalf("bad header: expected %q to be %q, got %q.", corsAllowOriginHeader, want, got) } } func TestCORSAllowStar(t *testing.T) { - r := newRequest("GET", "http://a.example.com") + r := newRequest(http.MethodGet, "http://a.example.com") r.Header.Set("Origin", r.URL.String()) rr := httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CORS()(testHandler).ServeHTTP(rr, r) - header := rr.HeaderMap.Get(corsAllowOriginHeader) + resp := rr.Result() + header := resp.Header.Get(corsAllowOriginHeader) if got, want := header, "*"; got != want { t.Fatalf("bad header: expected %q to be %q, got %q.", corsAllowOriginHeader, want, got) } diff --git a/go.mod b/go.mod index 58e6a85..79424f0 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,5 @@ module github.com/gorilla/handlers -go 1.14 +go 1.19 -require github.com/felixge/httpsnoop v1.0.1 +require github.com/felixge/httpsnoop v1.0.3 diff --git a/go.sum b/go.sum index 8c26458..e05d02e 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,2 @@ -github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8SPQ= -github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk= +github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= diff --git a/handlers.go b/handlers.go index 0509482..9b92fce 100644 --- a/handlers.go +++ b/handlers.go @@ -35,7 +35,7 @@ func (h MethodHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { } sort.Strings(allow) w.Header().Set("Allow", strings.Join(allow, ", ")) - if req.Method == "OPTIONS" { + if req.Method == http.MethodOptions { w.WriteHeader(http.StatusOK) } else { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) @@ -44,7 +44,7 @@ func (h MethodHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { } // responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP -// status code and body size +// status code and body size. type responseLogger struct { w http.ResponseWriter status int @@ -97,7 +97,7 @@ func isContentType(h http.Header, contentType string) bool { // Only PUT, POST, and PATCH requests are considered. func ContentTypeHandler(h http.Handler, contentTypes ...string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !(r.Method == "PUT" || r.Method == "POST" || r.Method == "PATCH") { + if !(r.Method == http.MethodPut || r.Method == http.MethodPost || r.Method == http.MethodPatch) { h.ServeHTTP(w, r) return } @@ -108,7 +108,10 @@ func ContentTypeHandler(h http.Handler, contentTypes ...string) http.Handler { return } } - http.Error(w, fmt.Sprintf("Unsupported content type %q; expected one of %q", r.Header.Get("Content-Type"), contentTypes), http.StatusUnsupportedMediaType) + http.Error(w, fmt.Sprintf("Unsupported content type %q; expected one of %q", + r.Header.Get("Content-Type"), + contentTypes), + http.StatusUnsupportedMediaType) }) } @@ -133,12 +136,12 @@ const ( // Form method takes precedence over header method. func HTTPMethodOverrideHandler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method == "POST" { + if r.Method == http.MethodPost { om := r.FormValue(HTTPMethodOverrideFormKey) if om == "" { om = r.Header.Get(HTTPMethodOverrideHeader) } - if om == "PUT" || om == "PATCH" || om == "DELETE" { + if om == http.MethodPut || om == http.MethodPatch || om == http.MethodDelete { r.Method = om } } diff --git a/handlers_go18_test.go b/handlers_go18_test.go deleted file mode 100644 index d8e6321..0000000 --- a/handlers_go18_test.go +++ /dev/null @@ -1,43 +0,0 @@ -// +build go1.8 - -package handlers - -import ( - "io/ioutil" - "net/http" - "net/http/httptest" - "testing" -) - -// *httptest.ResponseRecorder doesn't implement Pusher, so wrap it. -type pushRecorder struct { - *httptest.ResponseRecorder -} - -func (pr pushRecorder) Push(target string, opts *http.PushOptions) error { - return nil -} - -func TestLoggingHandlerWithPush(t *testing.T) { - handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - if _, ok := w.(http.Pusher); !ok { - t.Fatalf("%T from LoggingHandler does not satisfy http.Pusher interface when built with Go >=1.8", w) - } - w.WriteHeader(200) - }) - - logger := LoggingHandler(ioutil.Discard, handler) - logger.ServeHTTP(pushRecorder{httptest.NewRecorder()}, newRequest("GET", "/")) -} - -func TestCombinedLoggingHandlerWithPush(t *testing.T) { - handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - if _, ok := w.(http.Pusher); !ok { - t.Fatalf("%T from CombinedLoggingHandler does not satisfy http.Pusher interface when built with Go >=1.8", w) - } - w.WriteHeader(200) - }) - - logger := CombinedLoggingHandler(ioutil.Discard, handler) - logger.ServeHTTP(pushRecorder{httptest.NewRecorder()}, newRequest("GET", "/")) -} diff --git a/handlers_test.go b/handlers_test.go index 3f164a3..ad3c54e 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -5,6 +5,8 @@ package handlers import ( + "io" + "log" "net/http" "net/http/httptest" "net/url" @@ -18,7 +20,10 @@ const ( ) var okHandler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - w.Write([]byte(ok)) + _, err := w.Write([]byte(ok)) + if err != nil { + log.Fatalf("error on writing to http.ResponseWriter: %v", err) + } }) func newRequest(method, url string) *http.Request { @@ -38,33 +43,39 @@ func TestMethodHandler(t *testing.T) { body string }{ // No handlers - {newRequest("GET", "/foo"), MethodHandler{}, http.StatusMethodNotAllowed, "", notAllowed}, - {newRequest("OPTIONS", "/foo"), MethodHandler{}, http.StatusOK, "", ""}, + {newRequest(http.MethodGet, "/foo"), MethodHandler{}, http.StatusMethodNotAllowed, "", notAllowed}, + {newRequest(http.MethodOptions, "/foo"), MethodHandler{}, http.StatusOK, "", ""}, // A single handler - {newRequest("GET", "/foo"), MethodHandler{"GET": okHandler}, http.StatusOK, "", ok}, - {newRequest("POST", "/foo"), MethodHandler{"GET": okHandler}, http.StatusMethodNotAllowed, "GET", notAllowed}, + {newRequest(http.MethodGet, "/foo"), MethodHandler{http.MethodGet: okHandler}, http.StatusOK, "", ok}, + {newRequest(http.MethodPost, "/foo"), MethodHandler{http.MethodGet: okHandler}, http.StatusMethodNotAllowed, http.MethodGet, notAllowed}, // Multiple handlers - {newRequest("GET", "/foo"), MethodHandler{"GET": okHandler, "POST": okHandler}, http.StatusOK, "", ok}, - {newRequest("POST", "/foo"), MethodHandler{"GET": okHandler, "POST": okHandler}, http.StatusOK, "", ok}, - {newRequest("DELETE", "/foo"), MethodHandler{"GET": okHandler, "POST": okHandler}, http.StatusMethodNotAllowed, "GET, POST", notAllowed}, - {newRequest("OPTIONS", "/foo"), MethodHandler{"GET": okHandler, "POST": okHandler}, http.StatusOK, "GET, POST", ""}, + {newRequest(http.MethodGet, "/foo"), MethodHandler{http.MethodGet: okHandler, http.MethodPost: okHandler}, http.StatusOK, "", ok}, + {newRequest(http.MethodPost, "/foo"), MethodHandler{http.MethodGet: okHandler, http.MethodPost: okHandler}, http.StatusOK, "", ok}, + {newRequest(http.MethodDelete, "/foo"), MethodHandler{http.MethodGet: okHandler, http.MethodPost: okHandler}, http.StatusMethodNotAllowed, "GET, POST", notAllowed}, + {newRequest(http.MethodOptions, "/foo"), MethodHandler{http.MethodGet: okHandler, http.MethodPost: okHandler}, http.StatusOK, "GET, POST", ""}, // Override OPTIONS - {newRequest("OPTIONS", "/foo"), MethodHandler{"OPTIONS": okHandler}, http.StatusOK, "", ok}, + {newRequest(http.MethodOptions, "/foo"), MethodHandler{http.MethodOptions: okHandler}, http.StatusOK, "", ok}, } for i, test := range tests { rec := httptest.NewRecorder() test.handler.ServeHTTP(rec, test.req) - if rec.Code != test.code { - t.Fatalf("%d: wrong code, got %d want %d", i, rec.Code, test.code) + resp := rec.Result() + if resp.StatusCode != test.code { + t.Fatalf("%d: wrong code, got %d want %d", i, resp.StatusCode, test.code) } - if allow := rec.HeaderMap.Get("Allow"); allow != test.allow { + if allow := resp.Header.Get("Allow"); allow != test.allow { t.Fatalf("%d: wrong Allow, got %s want %s", i, allow, test.allow) } - if body := rec.Body.String(); body != test.body { + + respBodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + t.Errorf("io error while reading response body %v", err) + } + if body := string(respBodyBytes); body != test.body { t.Fatalf("%d: wrong body, got %q want %q", i, body, test.body) } } @@ -77,13 +88,13 @@ func TestContentTypeHandler(t *testing.T) { ContentType string Code int }{ - {"POST", []string{"application/json"}, "application/json", http.StatusOK}, - {"POST", []string{"application/json", "application/xml"}, "application/json", http.StatusOK}, - {"POST", []string{"application/json"}, "application/json; charset=utf-8", http.StatusOK}, - {"POST", []string{"application/json"}, "application/json+xxx", http.StatusUnsupportedMediaType}, - {"POST", []string{"application/json"}, "text/plain", http.StatusUnsupportedMediaType}, - {"GET", []string{"application/json"}, "", http.StatusOK}, - {"GET", []string{}, "", http.StatusOK}, + {http.MethodPost, []string{"application/json"}, "application/json", http.StatusOK}, + {http.MethodPost, []string{"application/json", "application/xml"}, "application/json", http.StatusOK}, + {http.MethodPost, []string{"application/json"}, "application/json; charset=utf-8", http.StatusOK}, + {http.MethodPost, []string{"application/json"}, "application/json+xxx", http.StatusUnsupportedMediaType}, + {http.MethodPost, []string{"application/json"}, "text/plain", http.StatusUnsupportedMediaType}, + {http.MethodGet, []string{"application/json"}, "", http.StatusOK}, + {http.MethodGet, []string{}, "", http.StatusOK}, } for _, test := range tests { r, err := http.NewRequest(test.Method, "/", nil) @@ -103,19 +114,19 @@ func TestContentTypeHandler(t *testing.T) { } func TestHTTPMethodOverride(t *testing.T) { - var tests = []struct { + tests := []struct { Method string OverrideMethod string ExpectedMethod string }{ - {"POST", "PUT", "PUT"}, - {"POST", "PATCH", "PATCH"}, - {"POST", "DELETE", "DELETE"}, - {"PUT", "DELETE", "PUT"}, - {"GET", "GET", "GET"}, - {"HEAD", "HEAD", "HEAD"}, - {"GET", "PUT", "GET"}, - {"HEAD", "DELETE", "HEAD"}, + {http.MethodPost, http.MethodPut, http.MethodPut}, + {http.MethodPost, http.MethodPatch, http.MethodPatch}, + {http.MethodPost, http.MethodDelete, http.MethodDelete}, + {http.MethodPut, http.MethodDelete, http.MethodPut}, + {http.MethodGet, http.MethodGet, http.MethodGet}, + {http.MethodHead, http.MethodHead, http.MethodHead}, + {http.MethodGet, http.MethodPut, http.MethodGet}, + {http.MethodHead, http.MethodDelete, http.MethodHead}, } for _, test := range tests { diff --git a/logging.go b/logging.go index 228465e..2badb6f 100644 --- a/logging.go +++ b/logging.go @@ -18,7 +18,7 @@ import ( // Logging -// LogFormatterParams is the structure any formatter will be handed when time to log comes +// LogFormatterParams is the structure any formatter will be handed when time to log comes. type LogFormatterParams struct { Request *http.Request URL url.URL @@ -27,7 +27,7 @@ type LogFormatterParams struct { Size int } -// LogFormatter gives the signature of the formatter function passed to CustomLoggingHandler +// LogFormatter gives the signature of the formatter function passed to CustomLoggingHandler. type LogFormatter func(writer io.Writer, params LogFormatterParams) // loggingHandler is the http.Handler implementation for LoggingHandlerTo and its @@ -46,7 +46,10 @@ func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { h.handler.ServeHTTP(w, req) if req.MultipartForm != nil { - req.MultipartForm.RemoveAll() + err := req.MultipartForm.RemoveAll() + if err != nil { + return + } } params := LogFormatterParams{ @@ -76,7 +79,7 @@ const lowerhex = "0123456789abcdef" func appendQuoted(buf []byte, s string) []byte { var runeTmp [utf8.UTFMax]byte - for width := 0; len(s) > 0; s = s[width:] { + for width := 0; len(s) > 0; s = s[width:] { //nolint: wastedassign //TODO: why width starts from 0and reassigned as 1 r := rune(s[0]) width = 1 if r >= utf8.RuneSelf { @@ -191,7 +194,7 @@ func buildCommonLogLine(req *http.Request, url url.URL, ts time.Time, status int func writeLog(writer io.Writer, params LogFormatterParams) { buf := buildCommonLogLine(params.Request, params.URL, params.TimeStamp, params.StatusCode, params.Size) buf = append(buf, '\n') - writer.Write(buf) + _, _ = writer.Write(buf) } // writeCombinedLog writes a log entry for req to w in Apache Combined Log Format. @@ -204,7 +207,7 @@ func writeCombinedLog(writer io.Writer, params LogFormatterParams) { buf = append(buf, `" "`...) buf = appendQuoted(buf, params.Request.UserAgent()) buf = append(buf, '"', '\n') - writer.Write(buf) + _, _ = writer.Write(buf) } // CombinedLoggingHandler return a http.Handler that wraps h and logs requests to out in @@ -212,7 +215,7 @@ func writeCombinedLog(writer io.Writer, params LogFormatterParams) { // // See http://httpd.apache.org/docs/2.2/logs.html#combined for a description of this format. // -// LoggingHandler always sets the ident field of the log to - +// LoggingHandler always sets the ident field of the log to -. func CombinedLoggingHandler(out io.Writer, h http.Handler) http.Handler { return loggingHandler{out, h, writeCombinedLog} } @@ -226,19 +229,18 @@ func CombinedLoggingHandler(out io.Writer, h http.Handler) http.Handler { // // Example: // -// r := mux.NewRouter() -// r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { -// w.Write([]byte("This is a catch-all route")) -// }) -// loggedRouter := handlers.LoggingHandler(os.Stdout, r) -// http.ListenAndServe(":1123", loggedRouter) -// +// r := mux.NewRouter() +// r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { +// w.Write([]byte("This is a catch-all route")) +// }) +// loggedRouter := handlers.LoggingHandler(os.Stdout, r) +// http.ListenAndServe(":1123", loggedRouter) func LoggingHandler(out io.Writer, h http.Handler) http.Handler { return loggingHandler{out, h, writeLog} } // CustomLoggingHandler provides a way to supply a custom log formatter -// while taking advantage of the mechanisms in this package +// while taking advantage of the mechanisms in this package. func CustomLoggingHandler(out io.Writer, h http.Handler, f LogFormatter) http.Handler { return loggingHandler{out, h, f} } diff --git a/logging_test.go b/logging_test.go index 13b8369..4a89cfd 100644 --- a/logging_test.go +++ b/logging_test.go @@ -6,10 +6,11 @@ package handlers import ( "bytes" + "crypto/rand" "encoding/base64" + "errors" "fmt" - "io/ioutil" - "math/rand" + "io/fs" "mime/multipart" "net/http" "net/http/httptest" @@ -34,7 +35,11 @@ func TestMakeLogger(t *testing.T) { t.Fatalf("wrong status, got %d want %d", logger.Status(), http.StatusInternalServerError) } // Write - w.Write([]byte(ok)) + _, err := w.Write([]byte(ok)) + if err != nil { + t.Fatalf("error while writing to http.ResponseWriter %v", err) + return + } if logger.Size() != len(ok) { t.Fatalf("wrong size, got %d want %d", logger.Size(), len(ok)) } @@ -46,8 +51,6 @@ func TestMakeLogger(t *testing.T) { } func TestLoggerCleanup(t *testing.T) { - rand.Seed(time.Now().UnixNano()) - rbuf := make([]byte, 128) if _, err := rand.Read(rbuf); err != nil { t.Fatalf("Failed to generate random content: %v", err) @@ -68,7 +71,7 @@ Content-Disposition: form-data; name="buzz"; filename="example.txt" t.Fatalf("Failed to read multipart form: %v", err) } - tmpFiles, err := ioutil.ReadDir(os.TempDir()) + tmpFiles, err := os.ReadDir(os.TempDir()) if err != nil { t.Fatalf("Failed to list %s: %v", os.TempDir(), err) } @@ -80,14 +83,13 @@ Content-Disposition: form-data; name="buzz"; filename="example.txt" } path := filepath.Join(os.TempDir(), f.Name()) - switch b, err := ioutil.ReadFile(path); { - case err != nil: + switch b, fileError := os.ReadFile(path); { + case fileError != nil: t.Fatalf("Failed to read %s: %v", path, err) case string(b) != contents: continue default: tmpFile = path - break } } @@ -95,19 +97,19 @@ Content-Disposition: form-data; name="buzz"; filename="example.txt" t.Fatal("Could not find multipart form tmp file") } - req := newRequest("GET", "/subdir/asdf") + req := newRequest(http.MethodGet, "/subdir/asdf") req.MultipartForm = form var buf bytes.Buffer handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL.Path = "/" // simulate http.StripPrefix and friends - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) }) logger := LoggingHandler(&buf, handler) logger.ServeHTTP(httptest.NewRecorder(), req) - if _, err := os.Stat(tmpFile); err == nil || !os.IsNotExist(err) { - t.Fatalf("Expected %s to not exist, got %v", tmpFile, err) + if _, osStatErr := os.Stat(tmpFile); osStatErr == nil || !errors.Is(osStatErr, fs.ErrNotExist) { + t.Fatalf("Expected %s to not exist, got %v", tmpFile, osStatErr) } } @@ -116,11 +118,11 @@ func TestLogPathRewrites(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL.Path = "/" // simulate http.StripPrefix and friends - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) }) logger := LoggingHandler(&buf, handler) - logger.ServeHTTP(httptest.NewRecorder(), newRequest("GET", "/subdir/asdf")) + logger.ServeHTTP(httptest.NewRecorder(), newRequest(http.MethodGet, "/subdir/asdf")) if !strings.Contains(buf.String(), "GET /subdir/asdf HTTP") { t.Fatalf("Got log %#v, wanted substring %#v", buf.String(), "GET /subdir/asdf HTTP") @@ -132,9 +134,9 @@ func BenchmarkWriteLog(b *testing.B) { if err != nil { b.Fatalf(err.Error()) } - ts := time.Date(1983, 05, 26, 3, 30, 45, 0, loc) + ts := time.Date(1983, 0o5, 26, 3, 30, 45, 0, loc) - req := newRequest("GET", "http://example.com") + req := newRequest(http.MethodGet, "http://example.com") req.RemoteAddr = "192.168.100.5" b.ResetTimer() @@ -216,7 +218,7 @@ func LoggingScenario1(t *testing.T, formatter LogFormatter, expected string) { if err != nil { panic(err) } - ts := time.Date(1983, 05, 26, 3, 30, 45, 0, loc) + ts := time.Date(1983, 0o5, 26, 3, 30, 45, 0, loc) // A typical request with an OK response req := constructTypicalRequestOk() @@ -243,7 +245,7 @@ func LoggingScenario2(t *testing.T, formatter LogFormatter, expected string) { if err != nil { panic(err) } - ts := time.Date(1983, 05, 26, 3, 30, 45, 0, loc) + ts := time.Date(1983, 0o5, 26, 3, 30, 45, 0, loc) // CONNECT request over http/2.0 req := constructConnectRequest() @@ -269,7 +271,7 @@ func LoggingScenario3(t *testing.T, formatter LogFormatter, expected string) { if err != nil { panic(err) } - ts := time.Date(1983, 05, 26, 3, 30, 45, 0, loc) + ts := time.Date(1983, 0o5, 26, 3, 30, 45, 0, loc) // Request with an unauthorized user req := constructTypicalRequestOk() @@ -296,7 +298,7 @@ func LoggingScenario4(t *testing.T, formatter LogFormatter, expected string) { if err != nil { panic(err) } - ts := time.Date(1983, 05, 26, 3, 30, 45, 0, loc) + ts := time.Date(1983, 0o5, 26, 3, 30, 45, 0, loc) // Request with url encoded parameters req := constructEncodedRequest() @@ -322,7 +324,7 @@ func LoggingScenario5(t *testing.T, formatter LogFormatter, expected string) { if err != nil { panic(err) } - ts := time.Date(1983, 05, 26, 3, 30, 45, 0, loc) + ts := time.Date(1983, 0o5, 26, 3, 30, 45, 0, loc) req := constructTypicalRequestOk() req.URL.User = url.User("kamil") @@ -344,9 +346,9 @@ func LoggingScenario5(t *testing.T, formatter LogFormatter, expected string) { } } -// A typical request with an OK response +// A typical request with an OK response. func constructTypicalRequestOk() *http.Request { - req := newRequest("GET", "http://example.com") + req := newRequest(http.MethodGet, "http://example.com") req.RemoteAddr = "192.168.100.5" req.Header.Set("Referer", "http://example.com") req.Header.Set( @@ -357,10 +359,10 @@ func constructTypicalRequestOk() *http.Request { return req } -// CONNECT request over http/2.0 +// CONNECT request over http/2.0. func constructConnectRequest() *http.Request { req := &http.Request{ - Method: "CONNECT", + Method: http.MethodConnect, Host: "www.example.com:443", Proto: "HTTP/2.0", ProtoMajor: 2, diff --git a/proxy_headers.go b/proxy_headers.go index ed939dc..281d753 100644 --- a/proxy_headers.go +++ b/proxy_headers.go @@ -18,7 +18,7 @@ var ( var ( // RFC7239 defines a new "Forwarded: " header designed to replace the // existing use of X-Forwarded-* headers. - // e.g. Forwarded: for=192.0.2.60;proto=https;by=203.0.113.43 + // e.g. Forwarded: for=192.0.2.60;proto=https;by=203.0.113.43. forwarded = http.CanonicalHeaderKey("Forwarded") // Allows for a sub-match of the first value after 'for=' to the next // comma, semi-colon or space. The match is case-insensitive. @@ -67,7 +67,9 @@ func ProxyHeaders(h http.Handler) http.Handler { func getIP(r *http.Request) string { var addr string - if fwd := r.Header.Get(xForwardedFor); fwd != "" { + switch { + case r.Header.Get(xForwardedFor) != "": + fwd := r.Header.Get(xForwardedFor) // Only grab the first (client) address. Note that '192.168.0.1, // 10.1.1.1' is a valid key for X-Forwarded-For where addresses after // the first may represent forwarding proxies earlier in the chain. @@ -76,17 +78,15 @@ func getIP(r *http.Request) string { s = len(fwd) } addr = fwd[:s] - } else if fwd := r.Header.Get(xRealIP); fwd != "" { - // X-Real-IP should only contain one IP address (the client making the - // request). - addr = fwd - } else if fwd := r.Header.Get(forwarded); fwd != "" { + case r.Header.Get(xRealIP) != "": + addr = r.Header.Get(xRealIP) + case r.Header.Get(forwarded) != "": // match should contain at least two elements if the protocol was // specified in the Forwarded header. The first element will always be // the 'for=' capture, which we ignore. In the case of multiple IP // addresses (for=8.8.8.8, 8.8.4.4,172.16.1.20 is valid) we only // extract the first, which should be the client IP. - if match := forRegex.FindStringSubmatch(fwd); len(match) > 1 { + if match := forRegex.FindStringSubmatch(r.Header.Get(forwarded)); len(match) > 1 { // IPv6 addresses in Forwarded headers are quoted-strings. We strip // these quotes. addr = strings.Trim(match[1], `"`) diff --git a/proxy_headers_test.go b/proxy_headers_test.go index b1fe685..620d387 100644 --- a/proxy_headers_test.go +++ b/proxy_headers_test.go @@ -33,7 +33,8 @@ func TestGetIP(t *testing.T) { req := &http.Request{ Header: http.Header{ v.key: []string{v.val}, - }} + }, + } res := getIP(req) if res != v.expected { t.Fatalf("wrong header for %s: got %s want %s", v.key, res, @@ -70,10 +71,10 @@ func TestGetScheme(t *testing.T) { } } -// Test the middleware end-to-end +// Test the middleware end-to-end. func TestProxyHeaders(t *testing.T) { rr := httptest.NewRecorder() - r := newRequest("GET", "/") + r := newRequest(http.MethodGet, "/") r.Header.Set(xForwardedFor, "8.8.8.8") r.Header.Set(xForwardedProto, "https") @@ -107,5 +108,4 @@ func TestProxyHeaders(t *testing.T) { t.Fatalf("wrong address: got %s want %s", host, r.Header.Get(xForwardedHost)) } - } diff --git a/recovery.go b/recovery.go index 4c4c1d9..0d4f955 100644 --- a/recovery.go +++ b/recovery.go @@ -36,12 +36,12 @@ func parseRecoveryOptions(h http.Handler, opts ...RecoveryOption) http.Handler { // // Example: // -// r := mux.NewRouter() -// r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { -// panic("Unexpected error!") -// }) +// r := mux.NewRouter() +// r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { +// panic("Unexpected error!") +// }) // -// http.ListenAndServe(":1123", handlers.RecoveryHandler()(r)) +// http.ListenAndServe(":1123", handlers.RecoveryHandler()(r)) func RecoveryHandler(opts ...RecoveryOption) func(h http.Handler) http.Handler { return func(h http.Handler) http.Handler { r := &recoveryHandler{handler: h} @@ -50,20 +50,22 @@ func RecoveryHandler(opts ...RecoveryOption) func(h http.Handler) http.Handler { } // RecoveryLogger is a functional option to override -// the default logger +// the default logger. func RecoveryLogger(logger RecoveryHandlerLogger) RecoveryOption { return func(h http.Handler) { - r := h.(*recoveryHandler) + r := h.(*recoveryHandler) //nolint:errcheck //TODO: + // @bharat-rajani should return type-assertion error but would break the API? r.logger = logger } } // PrintRecoveryStack is a functional option to enable // or disable printing stack traces on panic. -func PrintRecoveryStack(print bool) RecoveryOption { +func PrintRecoveryStack(shouldPrint bool) RecoveryOption { return func(h http.Handler) { - r := h.(*recoveryHandler) - r.printStack = print + r := h.(*recoveryHandler) //nolint:errcheck //TODO: + // @bharat-rajani should return type-assertion error but would break the API? + r.printStack = shouldPrint } } diff --git a/recovery_test.go b/recovery_test.go index f21e43d..0b6c5e4 100644 --- a/recovery_test.go +++ b/recovery_test.go @@ -19,7 +19,7 @@ func TestRecoveryLoggerWithDefaultOptions(t *testing.T) { }) recovery := handler(handlerFunc) - recovery.ServeHTTP(httptest.NewRecorder(), newRequest("GET", "/subdir/asdf")) + recovery.ServeHTTP(httptest.NewRecorder(), newRequest(http.MethodGet, "/subdir/asdf")) if !strings.Contains(buf.String(), "Unexpected error!") { t.Fatalf("Got log %#v, wanted substring %#v", buf.String(), "Unexpected error!") @@ -28,7 +28,7 @@ func TestRecoveryLoggerWithDefaultOptions(t *testing.T) { func TestRecoveryLoggerWithCustomLogger(t *testing.T) { var buf bytes.Buffer - var logger = log.New(&buf, "", log.LstdFlags) + logger := log.New(&buf, "", log.LstdFlags) handlerFunc := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { panic("Unexpected error!") @@ -38,7 +38,7 @@ func TestRecoveryLoggerWithCustomLogger(t *testing.T) { handler := RecoveryHandler(RecoveryLogger(logger), PrintRecoveryStack(false)) recovery := handler(handlerFunc) - recovery.ServeHTTP(httptest.NewRecorder(), newRequest("GET", "/subdir/asdf")) + recovery.ServeHTTP(httptest.NewRecorder(), newRequest(http.MethodGet, "/subdir/asdf")) if !strings.Contains(buf.String(), "Unexpected error!") { t.Fatalf("Got log %#v, wanted substring %#v", buf.String(), "Unexpected error!") @@ -49,7 +49,7 @@ func TestRecoveryLoggerWithCustomLogger(t *testing.T) { handler := RecoveryHandler(RecoveryLogger(logger), PrintRecoveryStack(true)) recovery := handler(handlerFunc) - recovery.ServeHTTP(httptest.NewRecorder(), newRequest("GET", "/subdir/asdf")) + recovery.ServeHTTP(httptest.NewRecorder(), newRequest(http.MethodGet, "/subdir/asdf")) if !strings.Contains(buf.String(), "runtime/debug.Stack") { t.Fatalf("Got log %#v, wanted substring %#v", buf.String(), "runtime/debug.Stack")