diff --git a/cors.go b/cors.go index dd9029a..6d1affe 100644 --- a/cors.go +++ b/cors.go @@ -37,14 +37,23 @@ type Options struct { // Only one wildcard can be used per origin. // Default value is ["*"] AllowedOrigins []string - // AllowOriginFunc is a custom function to validate the origin. It take the origin - // as argument and returns true if allowed or false otherwise. If this option is - // set, the content of AllowedOrigins is ignored. + // AllowOriginFunc is a custom function to validate the origin. It take the + // origin as argument and returns true if allowed or false otherwise. If + // this option is set, the content of `AllowedOrigins` is ignored. AllowOriginFunc func(origin string) bool - // AllowOriginRequestFunc is a custom function to validate the origin. It takes the HTTP Request object and the origin as - // argument and returns true if allowed or false otherwise. If this option is set, the content of `AllowedOrigins` - // and `AllowOriginFunc` is ignored. + // AllowOriginRequestFunc is a custom function to validate the origin. It + // takes the HTTP Request object and the origin as argument and returns true + // if allowed or false otherwise. If headers are used take the decision, + // consider using AllowOriginVaryRequestFunc instead. If this option is set, + // the content of `AllowedOrigins`, `AllowOriginFunc` are ignored. AllowOriginRequestFunc func(r *http.Request, origin string) bool + // AllowOriginVaryRequestFunc is a custom function to validate the origin. + // It takes the HTTP Request object and the origin as argument and returns + // true if allowed or false otherwise with a list of headers used to take + // that decision if any so they can be added to the Vary header. If this + // option is set, the content of `AllowedOrigins`, `AllowOriginFunc` and + // `AllowOriginRequestFunc` are ignored. + AllowOriginVaryRequestFunc func(r *http.Request, origin string) (bool, []string) // AllowedMethods is a list of methods the client is allowed to use with // cross-domain requests. Default value is simple methods (HEAD, GET and POST). AllowedMethods []string @@ -91,9 +100,7 @@ type Cors struct { // List of allowed origins containing wildcards allowedWOrigins []wildcard // Optional origin validator function - allowOriginFunc func(origin string) bool - // Optional origin validator (with request) function - allowOriginRequestFunc func(r *http.Request, origin string) bool + allowOriginFunc func(r *http.Request, origin string) (bool, []string) // Normalized list of allowed headers allowedHeaders []string // Normalized list of allowed methods @@ -115,26 +122,36 @@ type Cors struct { // New creates a new Cors handler with the provided options. func New(options Options) *Cors { c := &Cors{ - exposedHeaders: convert(options.ExposedHeaders, http.CanonicalHeaderKey), - allowOriginFunc: options.AllowOriginFunc, - allowOriginRequestFunc: options.AllowOriginRequestFunc, - allowCredentials: options.AllowCredentials, - allowPrivateNetwork: options.AllowPrivateNetwork, - maxAge: options.MaxAge, - optionPassthrough: options.OptionsPassthrough, - Log: options.Logger, + exposedHeaders: convert(options.ExposedHeaders, http.CanonicalHeaderKey), + allowCredentials: options.AllowCredentials, + allowPrivateNetwork: options.AllowPrivateNetwork, + maxAge: options.MaxAge, + optionPassthrough: options.OptionsPassthrough, + Log: options.Logger, } if options.Debug && c.Log == nil { c.Log = log.New(os.Stdout, "[cors] ", log.LstdFlags) } + if options.AllowOriginVaryRequestFunc != nil { + c.allowOriginFunc = options.AllowOriginVaryRequestFunc + } else if options.AllowOriginRequestFunc != nil { + c.allowOriginFunc = func(r *http.Request, origin string) (bool, []string) { + return options.AllowOriginRequestFunc(r, origin), nil + } + } else if options.AllowOriginFunc != nil { + c.allowOriginFunc = func(r *http.Request, origin string) (bool, []string) { + return options.AllowOriginFunc(origin), nil + } + } + // Normalize options // Note: for origins matching, the spec requires a case-sensitive matching. // As it may error prone, we chose to ignore the spec here. // Allowed Origins if len(options.AllowedOrigins) == 0 { - if options.AllowOriginFunc == nil && options.AllowOriginRequestFunc == nil { + if c.allowOriginFunc == nil { // Default is all origins c.allowedOriginsAll = true } @@ -294,11 +311,16 @@ func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) { headers.Add("Vary", "Access-Control-Request-Private-Network") } + allowed, additionalVaryHeaders := c.isOriginAllowed(r, origin) + if len(additionalVaryHeaders) > 0 { + headers.Add("Vary", strings.Join(convert(additionalVaryHeaders, http.CanonicalHeaderKey), ", ")) + } + if origin == "" { c.logf(" Preflight aborted: empty origin") return } - if !c.isOriginAllowed(r, origin) { + if !allowed { c.logf(" Preflight aborted: origin '%s' not allowed", origin) return } @@ -349,13 +371,18 @@ func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) { headers := w.Header() origin := r.Header.Get("Origin") + allowed, additionalVaryHeaders := c.isOriginAllowed(r, origin) + // Always set Vary, see https://github.com/rs/cors/issues/10 headers.Add("Vary", "Origin") + if len(additionalVaryHeaders) > 0 { + headers.Add("Vary", strings.Join(convert(additionalVaryHeaders, http.CanonicalHeaderKey), ", ")) + } if origin == "" { c.logf(" Actual request no headers added: missing origin") return } - if !c.isOriginAllowed(r, origin) { + if !allowed { c.logf(" Actual request no headers added: origin '%s' not allowed", origin) return } @@ -366,7 +393,6 @@ func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) { // We think it's a nice feature to be able to have control on those methods though. if !c.isMethodAllowed(r.Method) { c.logf(" Actual request no headers added: method '%s' not allowed", r.Method) - return } if c.allowedOriginsAll { @@ -393,33 +419,31 @@ func (c *Cors) logf(format string, a ...interface{}) { // check the Origin of a request. No origin at all is also allowed. func (c *Cors) OriginAllowed(r *http.Request) bool { origin := r.Header.Get("Origin") - return c.isOriginAllowed(r, origin) + allowed, _ := c.isOriginAllowed(r, origin) + return allowed } // isOriginAllowed checks if a given origin is allowed to perform cross-domain requests // on the endpoint -func (c *Cors) isOriginAllowed(r *http.Request, origin string) bool { - if c.allowOriginRequestFunc != nil { - return c.allowOriginRequestFunc(r, origin) - } +func (c *Cors) isOriginAllowed(r *http.Request, origin string) (allowed bool, varyHeaders []string) { if c.allowOriginFunc != nil { - return c.allowOriginFunc(origin) + return c.allowOriginFunc(r, origin) } if c.allowedOriginsAll { - return true + return true, nil } origin = strings.ToLower(origin) for _, o := range c.allowedOrigins { if o == origin { - return true + return true, nil } } for _, w := range c.allowedWOrigins { if w.match(origin) { - return true + return true, nil } } - return false + return false, nil } // isMethodAllowed checks if a given method can be used as part of a cross-domain request diff --git a/cors_test.go b/cors_test.go index 00a600e..3c05326 100644 --- a/cors_test.go +++ b/cors_test.go @@ -187,6 +187,24 @@ func TestSpec(t *testing.T) { }, true, }, + { + "AllowOriginVaryRequestFuncMatch", + Options{ + AllowOriginVaryRequestFunc: func(r *http.Request, o string) (bool, []string) { + return regexp.MustCompile("^http://foo").MatchString(o) && r.Header.Get("Authorization") == "secret", []string{"Authorization"} + }, + }, + "GET", + map[string]string{ + "Origin": "http://foobar.com", + "Authorization": "secret", + }, + map[string]string{ + "Vary": "Origin, Authorization", + "Access-Control-Allow-Origin": "http://foobar.com", + }, + true, + }, { "AllowOriginRequestFuncNotMatch", Options{