diff --git a/middleware/cors/utils_test.go b/middleware/cors/utils_test.go index adc729d05f..40b8345fa7 100644 --- a/middleware/cors/utils_test.go +++ b/middleware/cors/utils_test.go @@ -2,6 +2,8 @@ package cors import ( "testing" + + "github.com/stretchr/testify/assert" ) // go test -run -v Test_normalizeOrigin @@ -16,6 +18,9 @@ func Test_normalizeOrigin(t *testing.T) { {"http://example.com:3000", true, "http://example.com:3000"}, // Port should be preserved. {"http://example.com:3000/", true, "http://example.com:3000"}, // Trailing slash should be removed. {"http://", false, ""}, // Invalid origin should not be accepted. + {"file:///etc/passwd", false, ""}, // File scheme should not be accepted. + {"https://*example.com", false, ""}, // Wildcard domain should not be accepted. + {"http://*.example.com", false, ""}, // Wildcard subdomain should not be accepted. {"http://example.com/path", false, ""}, // Path should not be accepted. {"http://example.com?query=123", false, ""}, // Query should not be accepted. {"http://example.com#fragment", false, ""}, // Fragment should not be accepted. @@ -105,3 +110,50 @@ func Test_normalizeDomain(t *testing.T) { } } } + +func TestSubdomainMatch(t *testing.T) { + tests := []struct { + name string + sub subdomain + origin string + expected bool + }{ + { + name: "match with valid subdomain", + sub: subdomain{prefix: "https://api.", suffix: ".example.com"}, + origin: "https://api.service.example.com", + expected: true, + }, + { + name: "no match with invalid prefix", + sub: subdomain{prefix: "https://api.", suffix: ".example.com"}, + origin: "https://service.example.com", + expected: false, + }, + { + name: "no match with invalid suffix", + sub: subdomain{prefix: "https://api.", suffix: ".example.com"}, + origin: "https://api.example.org", + expected: false, + }, + { + name: "no match with empty origin", + sub: subdomain{prefix: "https://api.", suffix: ".example.com"}, + origin: "", + expected: false, + }, + { + name: "partial match not considered a match", + sub: subdomain{prefix: "https://api.", suffix: ".example.com"}, + origin: "https://api.example.com", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.sub.match(tt.origin) + assert.Equal(t, tt.expected, got, "subdomain.match()") + }) + } +}