Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(middleware/cors): Validation of multiple Origins #2883

Merged
merged 8 commits into from
Mar 1, 2024
26 changes: 14 additions & 12 deletions middleware/cors/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,24 +113,28 @@ func New(config ...Config) fiber.Handler {
panic("[CORS] Insecure setup, 'AllowCredentials' is set to true, and 'AllowOrigins' is set to a wildcard.")
}

// Validate and normalize static AllowOrigins if not using AllowOriginsFunc
if cfg.AllowOriginsFunc == nil && cfg.AllowOrigins != "" && cfg.AllowOrigins != "*" {
validatedOrigins := []string{}
// allowOrigins is a slice of strings that contains the allowed origins
// that are defined in the 'AllowOrigins' configuration.
var allowOrigins []string

// Validate and normalize static AllowOrigins
if cfg.AllowOrigins != "" && cfg.AllowOrigins != "*" {
for _, origin := range strings.Split(cfg.AllowOrigins, ",") {
origin = strings.TrimSpace(origin)
isValid, normalizedOrigin := normalizeOrigin(origin)
if isValid {
validatedOrigins = append(validatedOrigins, normalizedOrigin)
allowOrigins = append(allowOrigins, normalizedOrigin)
} else {
log.Warnf("[CORS] Invalid origin format in configuration: %s", origin)
panic("[CORS] Invalid origin provided in configuration")
}
}
cfg.AllowOrigins = strings.Join(validatedOrigins, ",")
} else {
// If AllowOrigins is set to a wildcard or not set,
// set the allowOrigins to a slice with a single element
allowOrigins = []string{cfg.AllowOrigins}
}

// Convert string to slice
allowOrigins := strings.Split(strings.ReplaceAll(cfg.AllowOrigins, " ", ""), ",")

// Strip white spaces
allowMethods := strings.ReplaceAll(cfg.AllowMethods, " ", "")
allowHeaders := strings.ReplaceAll(cfg.AllowHeaders, " ", "")
Expand Down Expand Up @@ -165,10 +169,8 @@ func New(config ...Config) fiber.Handler {
// Run AllowOriginsFunc if the logic for
// handling the value in 'AllowOrigins' does
// not result in allowOrigin being set.
if allowOrigin == "" && cfg.AllowOriginsFunc != nil {
if cfg.AllowOriginsFunc(originHeader) {
allowOrigin = originHeader
}
if allowOrigin == "" && cfg.AllowOriginsFunc != nil && cfg.AllowOriginsFunc(originHeader) {
allowOrigin = originHeader
}

// Simple request
Expand Down
75 changes: 75 additions & 0 deletions middleware/cors/cors_test.go
sixcolors marked this conversation as resolved.
Show resolved Hide resolved
sixcolors marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,33 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) {
RequestOrigin: "http://aaa.com",
ResponseOrigin: "http://aaa.com",
},
{
Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/MultipleOrigins/NoWhitespace/OriginAllowed",
Config: Config{
AllowOrigins: "http://aaa.com,http://bbb.com",
AllowOriginsFunc: nil,
},
RequestOrigin: "http://bbb.com",
ResponseOrigin: "http://bbb.com",
},
{
Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/MultipleOrigins/NoWhitespace/OriginNotAllowed",
Config: Config{
AllowOrigins: "http://aaa.com,http://bbb.com",
AllowOriginsFunc: nil,
},
RequestOrigin: "http://ccc.com",
ResponseOrigin: "",
},
{
Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/MultipleOrigins/Whitespace/OriginAllowed",
Config: Config{
AllowOrigins: "http://aaa.com, http://bbb.com",
AllowOriginsFunc: nil,
},
RequestOrigin: "http://aaa.com",
ResponseOrigin: "http://aaa.com",
},
{
Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/OriginNotAllowed",
Config: Config{
Expand Down Expand Up @@ -647,3 +674,51 @@ func Test_CORS_AllowCredentials(t *testing.T) {
})
}
}

func Benchmark_CORS_NewHandler(b *testing.B) {
sixcolors marked this conversation as resolved.
Show resolved Hide resolved
app := fiber.New()
app.Use(New(Config{
AllowOrigins: "http://localhost,http://example.com",
sixcolors marked this conversation as resolved.
Show resolved Hide resolved
AllowMethods: "GET,POST,PUT,DELETE",
AllowHeaders: "Origin,Content-Type,Accept",
AllowCredentials: true,
MaxAge: 600,
}))

req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set(fiber.HeaderOrigin, "http://localhost")

b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := app.Test(req)
sixcolors marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
b.Fatalf("Failed to perform request: %v", err)
}
_ = resp
}
}

func Benchmark_CORS_NewHandlerPreflight(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
AllowOrigins: "http://localhost,http://example.com",
AllowMethods: "GET,POST,PUT,DELETE",
AllowHeaders: "Origin,Content-Type,Accept",
AllowCredentials: true,
MaxAge: 600,
}))

req := httptest.NewRequest(fiber.MethodOptions, "/", nil)
req.Header.Set(fiber.HeaderOrigin, "http://localhost")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")

b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := app.Test(req)
sixcolors marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
b.Fatalf("Failed to perform request: %v", err)
}
_ = resp
}
}
Loading