From f2fe3ba4e9de76e91672fb65431218b453c911cd Mon Sep 17 00:00:00 2001 From: commit-master Date: Tue, 28 Aug 2018 17:31:04 +0200 Subject: [PATCH] [cors] add OPTIONS status code + fix function typo --- cors.go | 26 ++++++++++++++++++++------ cors_test.go | 19 +++++++++++++++++++ 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/cors.go b/cors.go index 1acf80d..52099a0 100644 --- a/cors.go +++ b/cors.go @@ -19,14 +19,16 @@ type cors struct { maxAge int ignoreOptions bool allowCredentials bool + optionStatusCode int } // OriginValidator takes an origin string and returns whether or not that origin is allowed. type OriginValidator func(string) bool var ( - defaultCorsMethods = []string{"GET", "HEAD", "POST"} - defaultCorsHeaders = []string{"Accept", "Accept-Language", "Content-Language", "Origin"} + defaultCorsOptionStatusCode = 200 + defaultCorsMethods = []string{"GET", "HEAD", "POST"} + defaultCorsHeaders = []string{"Accept", "Accept-Language", "Content-Language", "Origin"} // (WebKit/Safari v9 sends the Origin header by default in AJAX requests) ) @@ -130,6 +132,7 @@ func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set(corsAllowOriginHeader, returnOrigin) if r.Method == corsOptionMethod { + w.WriteHeader(ch.optionStatusCode) return } ch.h.ServeHTTP(w, r) @@ -164,9 +167,10 @@ func CORS(opts ...CORSOption) func(http.Handler) http.Handler { func parseCORSOptions(opts ...CORSOption) *cors { ch := &cors{ - allowedMethods: defaultCorsMethods, - allowedHeaders: defaultCorsHeaders, - allowedOrigins: []string{}, + allowedMethods: defaultCorsMethods, + allowedHeaders: defaultCorsHeaders, + allowedOrigins: []string{}, + optionStatusCode: defaultCorsOptionStatusCode, } for _, option := range opts { @@ -251,7 +255,17 @@ func AllowedOriginValidator(fn OriginValidator) CORSOption { } } -// ExposeHeaders can be used to specify headers that are available +// OptionStatusCode sets a custom status code to use to respond to OPTIONS requests +// by default it is set to 200 to ensure max compatibility, see: +// https://stackoverflow.com/questions/46026409/what-are-proper-status-codes-for-cors-preflight-requests +func OptionStatusCode(code int) CORSOption { + return func(ch *cors) error { + ch.optionStatusCode = code + return nil + } +} + +// ExposedHeaders can be used to specify headers that are available // and will not be stripped out by the user-agent. func ExposedHeaders(headers []string) CORSOption { return func(ch *cors) error { diff --git a/cors_test.go b/cors_test.go index a797be7..ad50e51 100644 --- a/cors_test.go +++ b/cors_test.go @@ -122,6 +122,25 @@ func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandler(t *testing.T) { } } +func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWithCustomStatusCode(t *testing.T) { + statusCode := 204 + r := newRequest("OPTIONS", "http://www.example.com/") + r.Header.Set("Origin", r.URL.String()) + r.Header.Set(corsRequestMethodHeader, "GET") + + rr := httptest.NewRecorder() + + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatal("Options request must not be passed to next handler") + }) + + CORS(OptionStatusCode(statusCode))(testHandler).ServeHTTP(rr, r) + + if status := rr.Code; status != statusCode { + t.Fatalf("bad status: got %v want %v", status, http.StatusOK) + } +} + func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWhenOriginNotAllowed(t *testing.T) { r := newRequest("OPTIONS", "http://www.example.com/") r.Header.Set("Origin", r.URL.String())