diff --git a/middleware/cors/cors.go b/middleware/cors/cors.go index a95986db9d..3cc6f4fa9f 100644 --- a/middleware/cors/cors.go +++ b/middleware/cors/cors.go @@ -13,7 +13,7 @@ type Config struct { // Next defines a function to skip this middleware when returned true. // // Optional. Default: nil - Next func(c fiber.Ctx) bool + Next func(c *fiber.Ctx) bool // AllowOriginsFunc defines a function that will set the 'Access-Control-Allow-Origin' // response header to the 'origin' request header when returned true. This allows for @@ -110,7 +110,7 @@ func New(config ...Config) fiber.Handler { // Validate CORS credentials configuration if cfg.AllowCredentials && cfg.AllowOrigins == "*" { - log.Panic("[CORS] Insecure setup, 'AllowCredentials' is set to true, and 'AllowOrigins' is set to a wildcard.") //nolint:revive // we want to exit the program + panic("[CORS] Insecure setup, 'AllowCredentials' is set to true, and 'AllowOrigins' is set to a wildcard.") } // allowOrigins is a slice of strings that contains the allowed origins @@ -125,7 +125,8 @@ func New(config ...Config) fiber.Handler { trimmedOrigin := strings.TrimSpace(origin) isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin) if !isValid { - log.Panicf("[CORS] Invalid origin format in configuration: %s", trimmedOrigin) //nolint:revive // we want to exit the program + log.Warnf("[CORS] Invalid origin format in configuration: %s", trimmedOrigin) + panic("[CORS] Invalid origin provided in configuration") } return normalizedOrigin, true } @@ -164,15 +165,16 @@ func New(config ...Config) fiber.Handler { // Return new handler return func(c fiber.Ctx) error { // Don't execute middleware if Next returns true - if cfg.Next != nil && cfg.Next(c) { + if cfg.Next != nil && cfg.Next(&c) { return c.Next() } // Get originHeader header originHeader := strings.ToLower(c.Get(fiber.HeaderOrigin)) - // If the request does not have an Origin header, the request is outside the scope of CORS - if originHeader == "" { + // If the request does not have Origin and Access-Control-Request-Method + // headers, the request is outside the scope of CORS + if originHeader == "" || c.Get(fiber.HeaderAccessControlRequestMethod) == "" { return c.Next() } @@ -210,8 +212,9 @@ func New(config ...Config) fiber.Handler { } // Simple request + // Ommit allowMethods and allowHeaders, only used for pre-flight requests if c.Method() != fiber.MethodOptions { - setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg) + setCORSHeaders(c, allowOrigin, "", "", exposeHeaders, maxAge, cfg) return c.Next() } @@ -219,7 +222,7 @@ func New(config ...Config) fiber.Handler { c.Vary(fiber.HeaderAccessControlRequestMethod) c.Vary(fiber.HeaderAccessControlRequestHeaders) - setCORSHeaders(ctx, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg) + setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg) // Send 204 No Content return c.SendStatus(fiber.StatusNoContent) @@ -227,19 +230,19 @@ func New(config ...Config) fiber.Handler { } // Function to set CORS headers -func setCORSHeaders(c *fiber.Ctx, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge string, cfg Config) { +func setCORSHeaders(c fiber.Ctx, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge string, cfg Config) { c.Vary(fiber.HeaderOrigin) if cfg.AllowCredentials { // When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*' - if allowOrigin != "*" && allowOrigin != "" { - c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) - c.Set(fiber.HeaderAccessControlAllowCredentials, "true") - } else if allowOrigin == "*" { + if allowOrigin == "*" { c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) log.Warn("[CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to '*'.") + } else if allowOrigin != "" { + c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) + c.Set(fiber.HeaderAccessControlAllowCredentials, "true") } - } else if len(allowOrigin) > 0 { + } else if allowOrigin != "" { // For non-credential requests, it's safe to set to '*' or specific origins c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) } diff --git a/middleware/cors/cors_test.go b/middleware/cors/cors_test.go index 99e84d8c52..43a58c811a 100644 --- a/middleware/cors/cors_test.go +++ b/middleware/cors/cors_test.go @@ -34,6 +34,7 @@ func Test_CORS_Negative_MaxAge(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") app.Handler()(ctx) @@ -48,6 +49,7 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) { // Test default GET response headers ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") h(ctx) @@ -58,6 +60,7 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) { // Test default OPTIONS (preflight) response headers ctx = &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") h(ctx) @@ -86,6 +89,7 @@ func Test_CORS_Wildcard(t *testing.T) { ctx.Request.SetRequestURI("/") ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) // Perform request handler(ctx) @@ -100,6 +104,7 @@ func Test_CORS_Wildcard(t *testing.T) { ctx = &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) handler(ctx) require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials))) @@ -127,6 +132,7 @@ func Test_CORS_Origin_AllowCredentials(t *testing.T) { ctx.Request.SetRequestURI("/") ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) // Perform request handler(ctx) @@ -140,6 +146,7 @@ func Test_CORS_Origin_AllowCredentials(t *testing.T) { // Test non OPTIONS (preflight) response headers ctx = &fasthttp.RequestCtx{} ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.SetMethod(fiber.MethodGet) handler(ctx) @@ -225,6 +232,7 @@ func Test_CORS_Subdomain(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com") // Perform request @@ -239,6 +247,7 @@ func Test_CORS_Subdomain(t *testing.T) { // Make request with domain only (disallowed) ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com") handler(ctx) @@ -251,6 +260,7 @@ func Test_CORS_Subdomain(t *testing.T) { // Make request with allowed origin ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://test.example.com") handler(ctx) @@ -365,6 +375,7 @@ func Test_CORS_AllowOriginScheme(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, tt.reqOrigin) handler(ctx) @@ -403,7 +414,7 @@ func Test_CORS_AllowOriginHeader_NoMatch(t *testing.T) { headerExists = true } }) - require.Equal(t, false, headerExists, "Access-Control-Allow-Origin header should not be set") + require.False(t, headerExists, "Access-Control-Allow-Origin header should not be set") } // go test -run Test_CORS_Next @@ -411,7 +422,7 @@ func Test_CORS_Next(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ - Next: func(_ fiber.Ctx) bool { + Next: func(_ *fiber.Ctx) bool { return true }, })) @@ -426,7 +437,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{})) - app.Use(func(c *fiber.Ctx) error { + app.Use(func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) @@ -443,6 +454,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) { handler := app.Handler() t.Run("Without origin and Access-Control-Request-Method headers", func(t *testing.T) { + t.Parallel() // Make request without origin header, and without Access-Control-Request-Method for _, method := range methods { ctx := &fasthttp.RequestCtx{} @@ -455,6 +467,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) { }) t.Run("With origin but without Access-Control-Request-Method header", func(t *testing.T) { + t.Parallel() // Make request with origin header, but without Access-Control-Request-Method for _, method := range methods { ctx := &fasthttp.RequestCtx{} @@ -468,6 +481,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) { }) t.Run("Without origin but with Access-Control-Request-Method header", func(t *testing.T) { + t.Parallel() // Make request without origin header, but with Access-Control-Request-Method for _, method := range methods { ctx := &fasthttp.RequestCtx{} @@ -481,6 +495,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) { }) t.Run("Preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) { + t.Parallel() // Make preflight request with origin header and with Access-Control-Request-Method for _, method := range methods { ctx := &fasthttp.RequestCtx{} @@ -497,6 +512,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) { }) t.Run("Non-preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) { + t.Parallel() // Make non-preflight request with origin header and with Access-Control-Request-Method for _, method := range methods { ctx := &fasthttp.RequestCtx{} @@ -531,6 +547,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com") // Perform request @@ -545,6 +562,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) { // Make request with allowed origin ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-1.com") handler(ctx) @@ -557,6 +575,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) { // Make request with allowed origin ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com") handler(ctx) @@ -596,6 +615,7 @@ func Test_CORS_AllowOriginsFunc(t *testing.T) { // Make request with allowed origin ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com") handler(ctx) @@ -743,6 +763,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin) handler(ctx) @@ -833,6 +854,7 @@ func Test_CORS_AllowCredentials(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin) handler(ctx)