diff --git a/middleware/cors/cors.go b/middleware/cors/cors.go index 2ca3767d1f..3accc2f1dd 100644 --- a/middleware/cors/cors.go +++ b/middleware/cors/cors.go @@ -113,24 +113,32 @@ func New(config ...Config) fiber.Handler { panic("[CORS] Insecure setup, 'AllowCredentials' is set to true, and 'AllowOrigins' is set to a wildcard.") } - // Validate and normalize static AllowOrigins if not using AllowOriginsFunc - if cfg.AllowOriginsFunc == nil && cfg.AllowOrigins != "" && cfg.AllowOrigins != "*" { - validatedOrigins := []string{} - for _, origin := range strings.Split(cfg.AllowOrigins, ",") { - isValid, normalizedOrigin := normalizeOrigin(origin) + // allowOrigins is a slice of strings that contains the allowed origins + // defined in the 'AllowOrigins' configuration. + var allowOrigins []string + + // Validate and normalize static AllowOrigins + if cfg.AllowOrigins != "" && cfg.AllowOrigins != "*" { + origins := strings.Split(cfg.AllowOrigins, ",") + allowOrigins = make([]string, len(origins)) + + for i, origin := range origins { + trimmedOrigin := strings.TrimSpace(origin) + isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin) + if isValid { - validatedOrigins = append(validatedOrigins, normalizedOrigin) + allowOrigins[i] = normalizedOrigin } else { - log.Warnf("[CORS] Invalid origin format in configuration: %s", origin) + log.Warnf("[CORS] Invalid origin format in configuration: %s", trimmedOrigin) panic("[CORS] Invalid origin provided in configuration") } } - cfg.AllowOrigins = strings.Join(validatedOrigins, ",") + } else { + // If AllowOrigins is set to a wildcard or not set, + // set allowOrigins to a slice with a single element + allowOrigins = []string{cfg.AllowOrigins} } - // Convert string to slice - allowOrigins := strings.Split(strings.ReplaceAll(cfg.AllowOrigins, " ", ""), ",") - // Strip white spaces allowMethods := strings.ReplaceAll(cfg.AllowMethods, " ", "") allowHeaders := strings.ReplaceAll(cfg.AllowHeaders, " ", "") @@ -165,10 +173,8 @@ func New(config ...Config) fiber.Handler { // Run AllowOriginsFunc if the logic for // handling the value in 'AllowOrigins' does // not result in allowOrigin being set. - if allowOrigin == "" && cfg.AllowOriginsFunc != nil { - if cfg.AllowOriginsFunc(originHeader) { - allowOrigin = originHeader - } + if allowOrigin == "" && cfg.AllowOriginsFunc != nil && cfg.AllowOriginsFunc(originHeader) { + allowOrigin = originHeader } // Simple request diff --git a/middleware/cors/cors_test.go b/middleware/cors/cors_test.go index 9fc2852556..23692c3f84 100644 --- a/middleware/cors/cors_test.go +++ b/middleware/cors/cors_test.go @@ -307,6 +307,21 @@ func Test_CORS_AllowOriginScheme(t *testing.T) { reqOrigin: "http://ccc.bbb.example.com", shouldAllowOrigin: false, }, + { + pattern: "http://domain-1.com, http://example.com", + reqOrigin: "http://example.com", + shouldAllowOrigin: true, + }, + { + pattern: "http://domain-1.com, http://example.com", + reqOrigin: "http://domain-2.com", + shouldAllowOrigin: false, + }, + { + pattern: "http://domain-1.com,http://example.com", + reqOrigin: "http://domain-1.com", + shouldAllowOrigin: true, + }, } for _, tt := range tests { @@ -452,6 +467,33 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) { RequestOrigin: "http://aaa.com", ResponseOrigin: "http://aaa.com", }, + { + Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/MultipleOrigins/NoWhitespace/OriginAllowed", + Config: Config{ + AllowOrigins: "http://aaa.com,http://bbb.com", + AllowOriginsFunc: nil, + }, + RequestOrigin: "http://bbb.com", + ResponseOrigin: "http://bbb.com", + }, + { + Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/MultipleOrigins/NoWhitespace/OriginNotAllowed", + Config: Config{ + AllowOrigins: "http://aaa.com,http://bbb.com", + AllowOriginsFunc: nil, + }, + RequestOrigin: "http://ccc.com", + ResponseOrigin: "", + }, + { + Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/MultipleOrigins/Whitespace/OriginAllowed", + Config: Config{ + AllowOrigins: "http://aaa.com, http://bbb.com", + AllowOriginsFunc: nil, + }, + RequestOrigin: "http://aaa.com", + ResponseOrigin: "http://aaa.com", + }, { Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/OriginNotAllowed", Config: Config{ @@ -647,3 +689,453 @@ func Test_CORS_AllowCredentials(t *testing.T) { }) } } + +// go test -v -run=^$ -bench=Benchmark_CORS_NewHandler -benchmem -count=4 +func Benchmark_CORS_NewHandler(b *testing.B) { + app := fiber.New() + c := New(Config{ + AllowOrigins: "http://localhost,http://example.com", + AllowMethods: "GET,POST,PUT,DELETE", + AllowHeaders: "Origin,Content-Type,Accept", + AllowCredentials: true, + MaxAge: 600, + }) + + app.Use(c) + app.Use(func(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + ctx := &fasthttp.RequestCtx{} + + req := &fasthttp.Request{} + req.Header.SetMethod(fiber.MethodGet) + req.SetRequestURI("/") + req.Header.Set(fiber.HeaderOrigin, "http://localhost") + req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) + req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") + + ctx.Init(req, nil, nil) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + h(ctx) + } +} + +// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerParallel -benchmem -count=4 +func Benchmark_CORS_NewHandlerParallel(b *testing.B) { + app := fiber.New() + c := New(Config{ + AllowOrigins: "http://localhost,http://example.com", + AllowMethods: "GET,POST,PUT,DELETE", + AllowHeaders: "Origin,Content-Type,Accept", + AllowCredentials: true, + MaxAge: 600, + }) + + app.Use(c) + app.Use(func(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + + b.ReportAllocs() + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + ctx := &fasthttp.RequestCtx{} + + req := &fasthttp.Request{} + req.Header.SetMethod(fiber.MethodGet) + req.SetRequestURI("/") + req.Header.Set(fiber.HeaderOrigin, "http://localhost") + req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) + req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") + + ctx.Init(req, nil, nil) + + for pb.Next() { + h(ctx) + } + }) +} + +// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerSingleOrigin -benchmem -count=4 +func Benchmark_CORS_NewHandlerSingleOrigin(b *testing.B) { + app := fiber.New() + c := New(Config{ + AllowOrigins: "http://example.com", + AllowMethods: "GET,POST,PUT,DELETE", + AllowHeaders: "Origin,Content-Type,Accept", + AllowCredentials: true, + MaxAge: 600, + }) + + app.Use(c) + app.Use(func(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + ctx := &fasthttp.RequestCtx{} + + req := &fasthttp.Request{} + req.Header.SetMethod(fiber.MethodGet) + req.SetRequestURI("/") + req.Header.Set(fiber.HeaderOrigin, "http://example.com") + req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) + req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") + + ctx.Init(req, nil, nil) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + h(ctx) + } +} + +// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerSingleOriginParallel -benchmem -count=4 +func Benchmark_CORS_NewHandlerSingleOriginParallel(b *testing.B) { + app := fiber.New() + c := New(Config{ + AllowOrigins: "http://example.com", + AllowMethods: "GET,POST,PUT,DELETE", + AllowHeaders: "Origin,Content-Type,Accept", + AllowCredentials: true, + MaxAge: 600, + }) + + app.Use(c) + app.Use(func(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + + b.ReportAllocs() + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + ctx := &fasthttp.RequestCtx{} + + req := &fasthttp.Request{} + req.Header.SetMethod(fiber.MethodGet) + req.SetRequestURI("/") + req.Header.Set(fiber.HeaderOrigin, "http://example.com") + req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) + req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") + + ctx.Init(req, nil, nil) + + for pb.Next() { + h(ctx) + } + }) +} + +// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerWildcard -benchmem -count=4 +func Benchmark_CORS_NewHandlerWildcard(b *testing.B) { + app := fiber.New() + c := New(Config{ + AllowOrigins: "*", + AllowMethods: "GET,POST,PUT,DELETE", + AllowHeaders: "Origin,Content-Type,Accept", + AllowCredentials: false, + MaxAge: 600, + }) + + app.Use(c) + app.Use(func(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + ctx := &fasthttp.RequestCtx{} + + req := &fasthttp.Request{} + req.Header.SetMethod(fiber.MethodGet) + req.SetRequestURI("/") + req.Header.Set(fiber.HeaderOrigin, "http://example.com") + req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) + req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") + + ctx.Init(req, nil, nil) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + h(ctx) + } +} + +// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerWildcardParallel -benchmem -count=4 +func Benchmark_CORS_NewHandlerWildcardParallel(b *testing.B) { + app := fiber.New() + c := New(Config{ + AllowOrigins: "*", + AllowMethods: "GET,POST,PUT,DELETE", + AllowHeaders: "Origin,Content-Type,Accept", + AllowCredentials: false, + MaxAge: 600, + }) + + app.Use(c) + app.Use(func(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + + b.ReportAllocs() + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + ctx := &fasthttp.RequestCtx{} + + req := &fasthttp.Request{} + req.Header.SetMethod(fiber.MethodGet) + req.SetRequestURI("/") + req.Header.Set(fiber.HeaderOrigin, "http://example.com") + req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) + req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") + + ctx.Init(req, nil, nil) + + for pb.Next() { + h(ctx) + } + }) +} + +// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflight -benchmem -count=4 +func Benchmark_CORS_NewHandlerPreflight(b *testing.B) { + app := fiber.New() + c := New(Config{ + AllowOrigins: "http://localhost,http://example.com", + AllowMethods: "GET,POST,PUT,DELETE", + AllowHeaders: "Origin,Content-Type,Accept", + AllowCredentials: true, + MaxAge: 600, + }) + + app.Use(c) + app.Use(func(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + ctx := &fasthttp.RequestCtx{} + + req := &fasthttp.Request{} + req.Header.SetMethod(fiber.MethodOptions) + req.SetRequestURI("/") + req.Header.Set(fiber.HeaderOrigin, "http://example.com") + req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost) + req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") + + ctx.Init(req, nil, nil) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + h(ctx) + } +} + +// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflightParallel -benchmem -count=4 +func Benchmark_CORS_NewHandlerPreflightParallel(b *testing.B) { + app := fiber.New() + c := New(Config{ + AllowOrigins: "http://localhost,http://example.com", + AllowMethods: "GET,POST,PUT,DELETE", + AllowHeaders: "Origin,Content-Type,Accept", + AllowCredentials: true, + MaxAge: 600, + }) + + app.Use(c) + app.Use(func(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + + b.ReportAllocs() + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + ctx := &fasthttp.RequestCtx{} + + req := &fasthttp.Request{} + req.Header.SetMethod(fiber.MethodOptions) + req.SetRequestURI("/") + req.Header.Set(fiber.HeaderOrigin, "http://example.com") + req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost) + req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") + + ctx.Init(req, nil, nil) + + for pb.Next() { + h(ctx) + } + }) +} + +// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflightSingleOrigin -benchmem -count=4 +func Benchmark_CORS_NewHandlerPreflightSingleOrigin(b *testing.B) { + app := fiber.New() + c := New(Config{ + AllowOrigins: "http://example.com", + AllowMethods: "GET,POST,PUT,DELETE", + AllowHeaders: "Origin,Content-Type,Accept", + AllowCredentials: true, + MaxAge: 600, + }) + + app.Use(c) + app.Use(func(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + ctx := &fasthttp.RequestCtx{} + + req := &fasthttp.Request{} + req.Header.SetMethod(fiber.MethodOptions) + req.SetRequestURI("/") + req.Header.Set(fiber.HeaderOrigin, "http://example.com") + req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost) + req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") + + ctx.Init(req, nil, nil) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + h(ctx) + } +} + +// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflightSingleOriginParallel -benchmem -count=4 +func Benchmark_CORS_NewHandlerPreflightSingleOriginParallel(b *testing.B) { + app := fiber.New() + c := New(Config{ + AllowOrigins: "http://example.com", + AllowMethods: "GET,POST,PUT,DELETE", + AllowHeaders: "Origin,Content-Type,Accept", + AllowCredentials: true, + MaxAge: 600, + }) + + app.Use(c) + app.Use(func(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + + b.ReportAllocs() + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + ctx := &fasthttp.RequestCtx{} + + req := &fasthttp.Request{} + req.Header.SetMethod(fiber.MethodOptions) + req.SetRequestURI("/") + req.Header.Set(fiber.HeaderOrigin, "http://example.com") + req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost) + req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") + + ctx.Init(req, nil, nil) + + for pb.Next() { + h(ctx) + } + }) +} + +// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflightWildcard -benchmem -count=4 +func Benchmark_CORS_NewHandlerPreflightWildcard(b *testing.B) { + app := fiber.New() + c := New(Config{ + AllowOrigins: "*", + AllowMethods: "GET,POST,PUT,DELETE", + AllowHeaders: "Origin,Content-Type,Accept", + AllowCredentials: false, + MaxAge: 600, + }) + + app.Use(c) + app.Use(func(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + ctx := &fasthttp.RequestCtx{} + + req := &fasthttp.Request{} + req.Header.SetMethod(fiber.MethodOptions) + req.SetRequestURI("/") + req.Header.Set(fiber.HeaderOrigin, "http://example.com") + req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost) + req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") + + ctx.Init(req, nil, nil) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + h(ctx) + } +} + +// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflightWildcardParallel -benchmem -count=4 +func Benchmark_CORS_NewHandlerPreflightWildcardParallel(b *testing.B) { + app := fiber.New() + c := New(Config{ + AllowOrigins: "*", + AllowMethods: "GET,POST,PUT,DELETE", + AllowHeaders: "Origin,Content-Type,Accept", + AllowCredentials: false, + MaxAge: 600, + }) + + app.Use(c) + app.Use(func(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + + b.ReportAllocs() + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + ctx := &fasthttp.RequestCtx{} + + req := &fasthttp.Request{} + req.Header.SetMethod(fiber.MethodOptions) + req.SetRequestURI("/") + req.Header.Set(fiber.HeaderOrigin, "http://example.com") + req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost) + req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") + + ctx.Init(req, nil, nil) + + for pb.Next() { + h(ctx) + } + }) +}