Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(middleware/cors): CORS handling #2937

Merged
merged 3 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions middleware/cors/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,19 @@ func New(config ...Config) fiber.Handler {
// Get originHeader header
originHeader := strings.ToLower(c.Get(fiber.HeaderOrigin))

// If the request does not have Origin and Access-Control-Request-Method
// headers, the request is outside the scope of CORS
if originHeader == "" || c.Get(fiber.HeaderAccessControlRequestMethod) == "" {
// If the request does not have Origin header, the request is outside the scope of CORS
if originHeader == "" {
// See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches
// Unless all origins are allowed, we include the Vary header to cache the response correctly
if !allowAllOrigins {
c.Vary(fiber.HeaderOrigin)
}

return c.Next()
}

// If it's a preflight request and doesn't have Access-Control-Request-Method header, it's outside the scope of CORS
if c.Method() == fiber.MethodOptions && c.Get(fiber.HeaderAccessControlRequestMethod) == "" {
return c.Next()
}

Expand Down Expand Up @@ -204,13 +214,23 @@ func New(config ...Config) fiber.Handler {
// Simple request
// Ommit allowMethods and allowHeaders, only used for pre-flight requests
if c.Method() != fiber.MethodOptions {
if !allowAllOrigins {
// See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches
c.Vary(fiber.HeaderOrigin)
}
setCORSHeaders(c, allowOrigin, "", "", exposeHeaders, maxAge, cfg)
return c.Next()
}

// Preflight request
// Pre-flight request

// Response to OPTIONS request should not be cached but,
// some caching can be configured to cache such responses.
// To Avoid poisoning the cache, we include the Vary header
// of preflight responses:
c.Vary(fiber.HeaderAccessControlRequestMethod)
c.Vary(fiber.HeaderAccessControlRequestHeaders)
c.Vary(fiber.HeaderOrigin)

setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg)

Expand All @@ -221,8 +241,6 @@ func New(config ...Config) fiber.Handler {

// Function to set CORS headers
func setCORSHeaders(c *fiber.Ctx, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge string, cfg Config) {
c.Vary(fiber.HeaderOrigin)

if cfg.AllowCredentials {
// When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*'
if allowOrigin == "*" {
Expand Down
72 changes: 32 additions & 40 deletions middleware/cors/cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) {
// Test default GET response headers
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
h(ctx)

Expand All @@ -70,6 +69,33 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) {
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge)))
}

func Test_CORS_AllowOrigins_Vary(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(
Config{
AllowOrigins: "http://localhost",
},
))

h := app.Handler()

// Test Vary header non-Cors request
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx)
utils.AssertEqual(t, true, strings.Contains(string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin), "Vary header should be set")

// Test Vary header Cors request
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
h(ctx)
utils.AssertEqual(t, true, strings.Contains(string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin), "Vary header should be set")
}

// go test -run -v Test_CORS_Wildcard
func Test_CORS_Wildcard(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -97,6 +123,7 @@ func Test_CORS_Wildcard(t *testing.T) {

// Check result
utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) // Validates request is not reflecting origin in the response
utils.AssertEqual(t, true, strings.Contains(string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin), "Vary header should be set")
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
utils.AssertEqual(t, "3600", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge)))
utils.AssertEqual(t, "Authentication", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)))
Expand All @@ -105,9 +132,9 @@ func Test_CORS_Wildcard(t *testing.T) {
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
handler(ctx)

utils.AssertEqual(t, false, strings.Contains(string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin), "Vary header should not be set")
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
utils.AssertEqual(t, "X-Request-ID", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlExposeHeaders)))
}
Expand Down Expand Up @@ -147,7 +174,6 @@ func Test_CORS_Origin_AllowCredentials(t *testing.T) {
// Test non OPTIONS (preflight) response headers
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.SetMethod(fiber.MethodGet)
handler(ctx)

Expand Down Expand Up @@ -466,7 +492,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) {
// Get handler pointer
handler := app.Handler()

t.Run("Without origin and Access-Control-Request-Method headers", func(t *testing.T) {
t.Run("Without origin", func(t *testing.T) {
t.Parallel()
// Make request without origin header, and without Access-Control-Request-Method
for _, method := range methods {
Expand All @@ -479,34 +505,6 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) {
}
})

t.Run("With origin but without Access-Control-Request-Method header", func(t *testing.T) {
t.Parallel()
// Make request with origin header, but without Access-Control-Request-Method
for _, method := range methods {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(method)
ctx.Request.SetRequestURI("https://example.com/")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
handler(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set")
}
})

t.Run("Without origin but with Access-Control-Request-Method header", func(t *testing.T) {
t.Parallel()
// Make request without origin header, but with Access-Control-Request-Method
for _, method := range methods {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(method)
ctx.Request.SetRequestURI("https://example.com/")
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
handler(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set")
}
})

t.Run("Preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) {
t.Parallel()
// Make preflight request with origin header and with Access-Control-Request-Method
Expand All @@ -524,15 +522,14 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) {
}
})

t.Run("Non-preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) {
t.Run("Non-preflight request with origin", func(t *testing.T) {
t.Parallel()
// Make non-preflight request with origin header and with Access-Control-Request-Method
for _, method := range methods {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(method)
ctx.Request.SetRequestURI("https://example.com/api/action")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, method)
handler(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should be set")
Expand Down Expand Up @@ -901,7 +898,6 @@ func Benchmark_CORS_NewHandler(b *testing.B) {
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://localhost")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")

ctx.Init(req, nil, nil)
Expand Down Expand Up @@ -942,7 +938,6 @@ func Benchmark_CORS_NewHandlerParallel(b *testing.B) {
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://localhost")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")

ctx.Init(req, nil, nil)
Expand Down Expand Up @@ -976,7 +971,6 @@ func Benchmark_CORS_NewHandlerSingleOrigin(b *testing.B) {
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")

ctx.Init(req, nil, nil)
Expand Down Expand Up @@ -1017,7 +1011,6 @@ func Benchmark_CORS_NewHandlerSingleOriginParallel(b *testing.B) {
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")

ctx.Init(req, nil, nil)
Expand Down Expand Up @@ -1051,7 +1044,6 @@ func Benchmark_CORS_NewHandlerWildcard(b *testing.B) {
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")

ctx.Init(req, nil, nil)
Expand Down Expand Up @@ -1092,7 +1084,6 @@ func Benchmark_CORS_NewHandlerWildcardParallel(b *testing.B) {
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")

ctx.Init(req, nil, nil)
Expand Down Expand Up @@ -1122,6 +1113,7 @@ func Benchmark_CORS_NewHandlerPreflight(b *testing.B) {
h := app.Handler()
ctx := &fasthttp.RequestCtx{}

// Preflight request
req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodOptions)
req.SetRequestURI("/")
Expand Down
Loading