From 5e55a4adb89fb64e8b13bbd302eeedac7a4ba5d8 Mon Sep 17 00:00:00 2001 From: Franklin Harding <32021905+fharding1@users.noreply.github.com> Date: Fri, 11 May 2018 18:30:14 -0700 Subject: [PATCH] Add CORSMethodMiddleware (#366) CORSMethodMiddleware sets the Access-Control-Allow-Methods response header on a request, by matching routes based only on paths. It also handles OPTIONS requests, by settings Access-Control-Allow-Methods, and then returning without calling the next HTTP handler. --- middleware.go | 44 +++++++++++++++++++++++++++++++++++++++++++- middleware_test.go | 41 +++++++++++++++++++++++++++++++++++++++++ mux_test.go | 8 ++++++++ 3 files changed, 92 insertions(+), 1 deletion(-) diff --git a/middleware.go b/middleware.go index cf6cfc33..fab9ae35 100644 --- a/middleware.go +++ b/middleware.go @@ -1,6 +1,9 @@ package mux -import "net/http" +import ( + "net/http" + "strings" +) // MiddlewareFunc is a function which receives an http.Handler and returns another http.Handler. // Typically, the returned handler is a closure which does something with the http.ResponseWriter and http.Request passed @@ -28,3 +31,42 @@ func (r *Router) Use(mwf ...MiddlewareFunc) { func (r *Router) useInterface(mw middleware) { r.middlewares = append(r.middlewares, mw) } + +// CORSMethodMiddleware sets the Access-Control-Allow-Methods response header +// on a request, by matching routes based only on paths. It also handles +// OPTIONS requests, by settings Access-Control-Allow-Methods, and then +// returning without calling the next http handler. +func CORSMethodMiddleware(r *Router) MiddlewareFunc { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + var allMethods []string + + err := r.Walk(func(route *Route, _ *Router, _ []*Route) error { + for _, m := range route.matchers { + if _, ok := m.(*routeRegexp); ok { + if m.Match(req, &RouteMatch{}) { + methods, err := route.GetMethods() + if err != nil { + return err + } + + allMethods = append(allMethods, methods...) + } + break + } + } + return nil + }) + + if err == nil { + w.Header().Set("Access-Control-Allow-Methods", strings.Join(append(allMethods, "OPTIONS"), ",")) + + if req.Method == "OPTIONS" { + return + } + } + + next.ServeHTTP(w, req) + }) + } +} diff --git a/middleware_test.go b/middleware_test.go index 93947e8c..acf4e160 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -3,6 +3,7 @@ package mux import ( "bytes" "net/http" + "net/http/httptest" "testing" ) @@ -334,3 +335,43 @@ func TestMiddlewareMethodMismatchSubrouter(t *testing.T) { t.Fatal("Middleware was called for a method mismatch") } } + +func TestCORSMethodMiddleware(t *testing.T) { + router := NewRouter() + + cases := []struct { + path string + response string + method string + testURL string + expectedAllowedMethods string + }{ + {"/g/{o}", "a", "POST", "/g/asdf", "POST,PUT,GET,OPTIONS"}, + {"/g/{o}", "b", "PUT", "/g/bla", "POST,PUT,GET,OPTIONS"}, + {"/g/{o}", "c", "GET", "/g/orilla", "POST,PUT,GET,OPTIONS"}, + {"/g", "d", "POST", "/g", "POST,OPTIONS"}, + } + + for _, tt := range cases { + router.HandleFunc(tt.path, stringHandler(tt.response)).Methods(tt.method) + } + + router.Use(CORSMethodMiddleware(router)) + + for _, tt := range cases { + rr := httptest.NewRecorder() + req := newRequest(tt.method, tt.testURL) + + router.ServeHTTP(rr, req) + + if rr.Body.String() != tt.response { + t.Errorf("Expected body '%s', found '%s'", tt.response, rr.Body.String()) + } + + allowedMethods := rr.HeaderMap.Get("Access-Control-Allow-Methods") + + if allowedMethods != tt.expectedAllowedMethods { + t.Errorf("Expected Access-Control-Allow-Methods '%s', found '%s'", tt.expectedAllowedMethods, allowedMethods) + } + } +} diff --git a/mux_test.go b/mux_test.go index 4591344e..af21329f 100644 --- a/mux_test.go +++ b/mux_test.go @@ -2315,6 +2315,14 @@ func stringMapEqual(m1, m2 map[string]string) bool { return true } +// stringHandler returns a handler func that writes a message 's' to the +// http.ResponseWriter. +func stringHandler(s string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(s)) + } +} + // newRequest is a helper function to create a new request with a method and url. // The request returned is a 'server' request as opposed to a 'client' one through // simulated write onto the wire and read off of the wire.