diff --git a/middleware/cors/cors.go b/middleware/cors/cors.go index 7debfdfaa0..8fd70b3c60 100644 --- a/middleware/cors/cors.go +++ b/middleware/cors/cors.go @@ -162,9 +162,19 @@ func New(config ...Config) fiber.Handler { // Get originHeader header originHeader := strings.ToLower(c.Get(fiber.HeaderOrigin)) - // If the request does not have Origin and Access-Control-Request-Method - // headers, the request is outside the scope of CORS - if originHeader == "" || c.Get(fiber.HeaderAccessControlRequestMethod) == "" { + // If the request does not have Origin header, the request is outside the scope of CORS + if originHeader == "" { + // See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches + // Unless all origins are allowed, we include the Vary header to cache the response correctly + if !allowAllOrigins { + c.Vary(fiber.HeaderOrigin) + } + + return c.Next() + } + + // If it's a preflight request and doesn't have Access-Control-Request-Method header, it's outside the scope of CORS + if c.Method() == fiber.MethodOptions && c.Get(fiber.HeaderAccessControlRequestMethod) == "" { return c.Next() } @@ -204,13 +214,23 @@ func New(config ...Config) fiber.Handler { // Simple request // Ommit allowMethods and allowHeaders, only used for pre-flight requests if c.Method() != fiber.MethodOptions { + if !allowAllOrigins { + // See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches + c.Vary(fiber.HeaderOrigin) + } setCORSHeaders(c, allowOrigin, "", "", exposeHeaders, maxAge, cfg) return c.Next() } - // Preflight request + // Pre-flight request + + // Response to OPTIONS request should not be cached but, + // some caching can be configured to cache such responses. + // To Avoid poisoning the cache, we include the Vary header + // of preflight responses: c.Vary(fiber.HeaderAccessControlRequestMethod) c.Vary(fiber.HeaderAccessControlRequestHeaders) + c.Vary(fiber.HeaderOrigin) setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg) @@ -221,8 +241,6 @@ func New(config ...Config) fiber.Handler { // Function to set CORS headers func setCORSHeaders(c *fiber.Ctx, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge string, cfg Config) { - c.Vary(fiber.HeaderOrigin) - if cfg.AllowCredentials { // When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*' if allowOrigin == "*" { diff --git a/middleware/cors/cors_test.go b/middleware/cors/cors_test.go index 2e0b5c2244..cea2ea5d9b 100644 --- a/middleware/cors/cors_test.go +++ b/middleware/cors/cors_test.go @@ -50,7 +50,6 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) { // Test default GET response headers ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) - ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") h(ctx) @@ -70,6 +69,44 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) { utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge))) } +func Test_CORS_AllowOrigins_Vary(t *testing.T) { + t.Parallel() + app := fiber.New() + app.Use(New( + Config{ + AllowOrigins: "http://localhost", + }, + )) + + h := app.Handler() + + // Test Vary header non-Cors request + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + h(ctx) + utils.AssertEqual(t, true, strings.Contains(string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin), "Vary header should be set for Origin") + + // Test Vary header Cors preflight request + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) + ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") + h(ctx) + vh := string(ctx.Response.Header.Peek(fiber.HeaderVary)) + utils.AssertEqual(t, true, strings.Contains(vh, fiber.HeaderOrigin), "Vary header should be set for Origin") + utils.AssertEqual(t, true, strings.Contains(vh, fiber.HeaderAccessControlRequestMethod), "Vary header should be set for Access-Control-Request-Method") + utils.AssertEqual(t, true, strings.Contains(vh, fiber.HeaderAccessControlRequestHeaders), "Vary header should be set for Access-Control-Request-Headers") + + // Test Vary header Cors request + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") + h(ctx) + utils.AssertEqual(t, true, strings.Contains(string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin), "Vary header should be set for Origin") +} + // go test -run -v Test_CORS_Wildcard func Test_CORS_Wildcard(t *testing.T) { t.Parallel() @@ -97,6 +134,10 @@ func Test_CORS_Wildcard(t *testing.T) { // Check result utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) // Validates request is not reflecting origin in the response + vh := string(ctx.Response.Header.Peek(fiber.HeaderVary)) + utils.AssertEqual(t, true, strings.Contains(vh, fiber.HeaderOrigin), "Vary header should be set for Origin") + utils.AssertEqual(t, true, strings.Contains(vh, fiber.HeaderAccessControlRequestMethod), "Vary header should be set for Access-Control-Request-Method") + utils.AssertEqual(t, true, strings.Contains(vh, fiber.HeaderAccessControlRequestHeaders), "Vary header should be set for Access-Control-Request-Headers") utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials))) utils.AssertEqual(t, "3600", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge))) utils.AssertEqual(t, "Authentication", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders))) @@ -105,9 +146,9 @@ func Test_CORS_Wildcard(t *testing.T) { ctx = &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") - ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) handler(ctx) + utils.AssertEqual(t, false, strings.Contains(string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin), "Vary header should not be set for Origin") utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials))) utils.AssertEqual(t, "X-Request-ID", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlExposeHeaders))) } @@ -147,7 +188,6 @@ func Test_CORS_Origin_AllowCredentials(t *testing.T) { // Test non OPTIONS (preflight) response headers ctx = &fasthttp.RequestCtx{} ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") - ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.SetMethod(fiber.MethodGet) handler(ctx) @@ -466,7 +506,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) { // Get handler pointer handler := app.Handler() - t.Run("Without origin and Access-Control-Request-Method headers", func(t *testing.T) { + t.Run("Without origin", func(t *testing.T) { t.Parallel() // Make request without origin header, and without Access-Control-Request-Method for _, method := range methods { @@ -479,34 +519,6 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) { } }) - t.Run("With origin but without Access-Control-Request-Method header", func(t *testing.T) { - t.Parallel() - // Make request with origin header, but without Access-Control-Request-Method - for _, method := range methods { - ctx := &fasthttp.RequestCtx{} - ctx.Request.Header.SetMethod(method) - ctx.Request.SetRequestURI("https://example.com/") - ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com") - handler(ctx) - utils.AssertEqual(t, 200, ctx.Response.StatusCode(), "Status code should be 200") - utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set") - } - }) - - t.Run("Without origin but with Access-Control-Request-Method header", func(t *testing.T) { - t.Parallel() - // Make request without origin header, but with Access-Control-Request-Method - for _, method := range methods { - ctx := &fasthttp.RequestCtx{} - ctx.Request.Header.SetMethod(method) - ctx.Request.SetRequestURI("https://example.com/") - ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) - handler(ctx) - utils.AssertEqual(t, 200, ctx.Response.StatusCode(), "Status code should be 200") - utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set") - } - }) - t.Run("Preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) { t.Parallel() // Make preflight request with origin header and with Access-Control-Request-Method @@ -524,7 +536,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) { } }) - t.Run("Non-preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) { + t.Run("Non-preflight request with origin", func(t *testing.T) { t.Parallel() // Make non-preflight request with origin header and with Access-Control-Request-Method for _, method := range methods { @@ -532,7 +544,6 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) { ctx.Request.Header.SetMethod(method) ctx.Request.SetRequestURI("https://example.com/api/action") ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com") - ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, method) handler(ctx) utils.AssertEqual(t, 200, ctx.Response.StatusCode(), "Status code should be 200") utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should be set") @@ -901,7 +912,6 @@ func Benchmark_CORS_NewHandler(b *testing.B) { req.Header.SetMethod(fiber.MethodGet) req.SetRequestURI("/") req.Header.Set(fiber.HeaderOrigin, "http://localhost") - req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") ctx.Init(req, nil, nil) @@ -942,7 +952,6 @@ func Benchmark_CORS_NewHandlerParallel(b *testing.B) { req.Header.SetMethod(fiber.MethodGet) req.SetRequestURI("/") req.Header.Set(fiber.HeaderOrigin, "http://localhost") - req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") ctx.Init(req, nil, nil) @@ -976,7 +985,6 @@ func Benchmark_CORS_NewHandlerSingleOrigin(b *testing.B) { req.Header.SetMethod(fiber.MethodGet) req.SetRequestURI("/") req.Header.Set(fiber.HeaderOrigin, "http://example.com") - req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") ctx.Init(req, nil, nil) @@ -1017,7 +1025,6 @@ func Benchmark_CORS_NewHandlerSingleOriginParallel(b *testing.B) { req.Header.SetMethod(fiber.MethodGet) req.SetRequestURI("/") req.Header.Set(fiber.HeaderOrigin, "http://example.com") - req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") ctx.Init(req, nil, nil) @@ -1051,7 +1058,6 @@ func Benchmark_CORS_NewHandlerWildcard(b *testing.B) { req.Header.SetMethod(fiber.MethodGet) req.SetRequestURI("/") req.Header.Set(fiber.HeaderOrigin, "http://example.com") - req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") ctx.Init(req, nil, nil) @@ -1092,7 +1098,6 @@ func Benchmark_CORS_NewHandlerWildcardParallel(b *testing.B) { req.Header.SetMethod(fiber.MethodGet) req.SetRequestURI("/") req.Header.Set(fiber.HeaderOrigin, "http://example.com") - req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") ctx.Init(req, nil, nil) @@ -1122,6 +1127,7 @@ func Benchmark_CORS_NewHandlerPreflight(b *testing.B) { h := app.Handler() ctx := &fasthttp.RequestCtx{} + // Preflight request req := &fasthttp.Request{} req.Header.SetMethod(fiber.MethodOptions) req.SetRequestURI("/")