From 080e86e16813e410717c9806ea5a4ea9295724d1 Mon Sep 17 00:00:00 2001 From: Olivier Poitrey Date: Tue, 5 Sep 2023 16:14:03 +0200 Subject: [PATCH] Remove most allocations --- cors.go | 93 ++++++++++++++++++++++++++++++------------------- cors_test.go | 19 ++++++----- utils.go | 95 ++++++++++++++++++++++++++------------------------- utils_test.go | 74 +++++++++++++++++++-------------------- 4 files changed, 153 insertions(+), 128 deletions(-) diff --git a/cors.go b/cors.go index 6e4e80a..c330a93 100644 --- a/cors.go +++ b/cors.go @@ -28,6 +28,10 @@ import ( "strings" ) +var headerVaryOrigin = []string{"Origin"} +var headerOriginAll = []string{"*"} +var headerTrue = []string{"true"} + // Options is a configuration container to setup the CORS middleware. type Options struct { // AllowedOrigins is a list of origins a cross-domain request can be executed from. @@ -108,9 +112,10 @@ type Cors struct { allowedHeaders []string // Normalized list of allowed methods allowedMethods []string - // Normalized list of exposed headers + // Pre-computed normalized list of exposed headers exposedHeaders []string - maxAge int + // Pre-computed maxAge header value + maxAge []string // Set to true when allowed origins contains a "*" allowedOriginsAll bool // Set to true when allowed headers contains a "*" @@ -120,15 +125,14 @@ type Cors struct { allowCredentials bool allowPrivateNetwork bool optionPassthrough bool + preflightVary []string } // New creates a new Cors handler with the provided options. func New(options Options) *Cors { c := &Cors{ - exposedHeaders: convert(options.ExposedHeaders, http.CanonicalHeaderKey), allowCredentials: options.AllowCredentials, allowPrivateNetwork: options.AllowPrivateNetwork, - maxAge: options.MaxAge, optionPassthrough: options.OptionsPassthrough, Log: options.Logger, } @@ -211,6 +215,23 @@ func New(options Options) *Cors { c.optionsSuccessStatus = options.OptionsSuccessStatus } + // Pre-compute exposed headers header value + c.exposedHeaders = []string{strings.Join(convert(options.ExposedHeaders, http.CanonicalHeaderKey), ", ")} + + // Pre-compute prefight Vary header to save allocations + if c.allowPrivateNetwork { + c.preflightVary = []string{"Origin, Access-Control-Request-Method, Access-Control-Request-Headers, Access-Control-Request-Private-Network"} + } else { + c.preflightVary = []string{"Origin, Access-Control-Request-Method, Access-Control-Request-Headers"} + } + + // Precompute max-age + if options.MaxAge > 0 { + c.maxAge = []string{strconv.Itoa(options.MaxAge)} + } else if options.MaxAge < 0 { + c.maxAge = []string{"0"} + } + return c } @@ -307,13 +328,11 @@ func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) { // Always set Vary headers // see https://github.com/rs/cors/issues/10, // https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001 - headers.Add("Vary", "Origin") - headers.Add("Vary", "Access-Control-Request-Method") - headers.Add("Vary", "Access-Control-Request-Headers") - if c.allowPrivateNetwork { - headers.Add("Vary", "Access-Control-Request-Private-Network") + if vary, found := headers["Vary"]; found { + headers["Vary"] = append(vary, c.preflightVary[0]) + } else { + headers["Vary"] = c.preflightVary } - allowed, additionalVaryHeaders := c.isOriginAllowed(r, origin) if len(additionalVaryHeaders) > 0 { headers.Add("Vary", strings.Join(convert(additionalVaryHeaders, http.CanonicalHeaderKey), ", ")) @@ -333,42 +352,41 @@ func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) { c.logf(" Preflight aborted: method '%s' not allowed", reqMethod) return } - // Amazon API Gateway is sometimes feeding multiple values for - // Access-Control-Request-Headers in a way where r.Header.Values() picks - // them all up, but r.Header.Get() does not. - // I suspect it is something like this: https://stackoverflow.com/a/4371395 - reqHeaderList := strings.Join(r.Header.Values("Access-Control-Request-Headers"), ",") - reqHeaders := parseHeaderList(reqHeaderList) + reqHeadersRaw := r.Header["Access-Control-Request-Headers"] + reqHeaders, reqHeadersEdited := convertDidCopy(splitHeaderValues(reqHeadersRaw), http.CanonicalHeaderKey) if !c.areHeadersAllowed(reqHeaders) { c.logf(" Preflight aborted: headers '%v' not allowed", reqHeaders) return } if c.allowedOriginsAll { - headers.Set("Access-Control-Allow-Origin", "*") + headers["Access-Control-Allow-Origin"] = headerOriginAll } else { - headers.Set("Access-Control-Allow-Origin", origin) + headers["Access-Control-Allow-Origin"] = r.Header["Origin"] } // Spec says: Since the list of methods can be unbounded, simply returning the method indicated // by Access-Control-Request-Method (if supported) can be enough - headers.Set("Access-Control-Allow-Methods", reqMethod) + headers["Access-Control-Allow-Methods"] = r.Header["Access-Control-Request-Method"] if len(reqHeaders) > 0 { - // Spec says: Since the list of headers can be unbounded, simply returning supported headers // from Access-Control-Request-Headers can be enough - headers.Set("Access-Control-Allow-Headers", strings.Join(reqHeaders, ", ")) + if reqHeadersEdited || len(reqHeaders) != len(reqHeadersRaw) { + headers.Set("Access-Control-Allow-Headers", strings.Join(reqHeaders, ", ")) + } else { + headers["Access-Control-Allow-Headers"] = reqHeadersRaw + } } if c.allowCredentials { - headers.Set("Access-Control-Allow-Credentials", "true") + headers["Access-Control-Allow-Credentials"] = headerTrue } if c.allowPrivateNetwork && r.Header.Get("Access-Control-Request-Private-Network") == "true" { - headers.Set("Access-Control-Allow-Private-Network", "true") + headers["Access-Control-Allow-Private-Network"] = headerTrue + } + if len(c.maxAge) > 0 { + headers["Access-Control-Max-Age"] = c.maxAge } - if c.maxAge > 0 { - headers.Set("Access-Control-Max-Age", strconv.Itoa(c.maxAge)) - } else if c.maxAge < 0 { - headers.Set("Access-Control-Max-Age", "0") + if c.Log != nil { + c.logf(" Preflight response headers: %v", headers) } - c.logf(" Preflight response headers: %v", headers) } // handleActualRequest handles simple cross-origin requests, actual request or redirects @@ -379,7 +397,11 @@ func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) { allowed, additionalVaryHeaders := c.isOriginAllowed(r, origin) // Always set Vary, see https://github.com/rs/cors/issues/10 - headers.Add("Vary", "Origin") + if vary, found := headers["Vary"]; found { + headers["Vary"] = append(vary, headerVaryOrigin[0]) + } else { + headers["Vary"] = headerVaryOrigin + } if len(additionalVaryHeaders) > 0 { headers.Add("Vary", strings.Join(convert(additionalVaryHeaders, http.CanonicalHeaderKey), ", ")) } @@ -401,17 +423,19 @@ func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) { return } if c.allowedOriginsAll { - headers.Set("Access-Control-Allow-Origin", "*") + headers["Access-Control-Allow-Origin"] = headerOriginAll } else { - headers.Set("Access-Control-Allow-Origin", origin) + headers["Access-Control-Allow-Origin"] = r.Header["Origin"] } if len(c.exposedHeaders) > 0 { - headers.Set("Access-Control-Expose-Headers", strings.Join(c.exposedHeaders, ", ")) + headers["Access-Control-Expose-Headers"] = c.exposedHeaders } if c.allowCredentials { - headers.Set("Access-Control-Allow-Credentials", "true") + headers["Access-Control-Allow-Credentials"] = headerTrue + } + if c.Log != nil { + c.logf(" Actual response added headers: %v", headers) } - c.logf(" Actual response added headers: %v", headers) } // convenience method. checks if a logger is set. @@ -477,7 +501,6 @@ func (c *Cors) areHeadersAllowed(requestedHeaders []string) bool { return true } for _, header := range requestedHeaders { - header = http.CanonicalHeaderKey(header) found := false for _, h := range c.allowedHeaders { if h == header { diff --git a/cors_test.go b/cors_test.go index a46233c..473bc35 100644 --- a/cors_test.go +++ b/cors_test.go @@ -10,8 +10,9 @@ import ( "testing" ) +var testResponse = []byte("bar") var testHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("bar")) + _, _ = w.Write(testResponse) }) var allHeaders = []string{ @@ -26,6 +27,7 @@ var allHeaders = []string{ } func assertHeaders(t *testing.T, resHeaders http.Header, expHeaders map[string]string) { + t.Helper() for _, name := range allHeaders { got := strings.Join(resHeaders[name], ", ") want := expHeaders[name] @@ -36,6 +38,7 @@ func assertHeaders(t *testing.T, resHeaders http.Header, expHeaders map[string]s } func assertResponse(t *testing.T, res *httptest.ResponseRecorder, responseCode int) { + t.Helper() if responseCode != res.Code { t.Errorf("assertResponse: expected response code to be %d but got %d. ", responseCode, res.Code) } @@ -721,37 +724,37 @@ func TestCorsAreHeadersAllowed(t *testing.T) { { name: "nil allowedHeaders", allowedHeaders: nil, - requestedHeaders: parseHeaderList("X-PINGOTHER, Content-Type"), + requestedHeaders: []string{"X-PINGOTHER, Content-Type"}, want: false, }, { name: "star allowedHeaders", allowedHeaders: []string{"*"}, - requestedHeaders: parseHeaderList("X-PINGOTHER, Content-Type"), + requestedHeaders: []string{"X-PINGOTHER, Content-Type"}, want: true, }, { name: "empty reqHeader", allowedHeaders: nil, - requestedHeaders: parseHeaderList(""), + requestedHeaders: []string{}, want: true, }, { name: "match allowedHeaders", allowedHeaders: []string{"Content-Type", "X-PINGOTHER", "X-APP-KEY"}, - requestedHeaders: parseHeaderList("X-PINGOTHER, Content-Type"), + requestedHeaders: []string{"X-PINGOTHER, Content-Type"}, want: true, }, { name: "not matched allowedHeaders", allowedHeaders: []string{"X-PINGOTHER"}, - requestedHeaders: parseHeaderList("X-API-KEY, Content-Type"), + requestedHeaders: []string{"X-API-KEY, Content-Type"}, want: false, }, { name: "allowedHeaders should be a superset of requestedHeaders", allowedHeaders: []string{"X-PINGOTHER"}, - requestedHeaders: parseHeaderList("X-PINGOTHER, Content-Type"), + requestedHeaders: []string{"X-PINGOTHER, Content-Type"}, want: false, }, } @@ -761,7 +764,7 @@ func TestCorsAreHeadersAllowed(t *testing.T) { t.Run(tt.name, func(t *testing.T) { c := New(Options{AllowedHeaders: tt.allowedHeaders}) - have := c.areHeadersAllowed(tt.requestedHeaders) + have := c.areHeadersAllowed(convert(splitHeaderValues(tt.requestedHeaders), http.CanonicalHeaderKey)) if have != tt.want { t.Errorf("Cors.areHeadersAllowed() have: %t want: %t", have, tt.want) } diff --git a/utils.go b/utils.go index 6bb120c..ca9983d 100644 --- a/utils.go +++ b/utils.go @@ -1,8 +1,8 @@ package cors -import "strings" - -const toLower = 'a' - 'A' +import ( + "strings" +) type converter func(string) string @@ -15,57 +15,58 @@ func (w wildcard) match(s string) bool { return len(s) >= len(w.prefix)+len(w.suffix) && strings.HasPrefix(s, w.prefix) && strings.HasSuffix(s, w.suffix) } -// convert converts a list of string using the passed converter function -func convert(s []string, c converter) []string { - out := []string{} - for _, i := range s { - out = append(out, c(i)) - } - return out -} - -// parseHeaderList tokenize + normalize a string containing a list of headers -func parseHeaderList(headerList string) []string { - l := len(headerList) - h := make([]byte, 0, l) - upper := true - // Estimate the number headers in order to allocate the right splice size - t := 0 - for i := 0; i < l; i++ { - if headerList[i] == ',' { - t++ - } - } - headers := make([]string, 0, t) - for i := 0; i < l; i++ { - b := headerList[i] - switch { - case b >= 'a' && b <= 'z': - if upper { - h = append(h, b-toLower) - } else { - h = append(h, b) +// split compounded header values ["foo, bar", "baz"] -> ["foo", "bar", "baz"] +func splitHeaderValues(values []string) []string { + out := values + copied := false + for i, v := range values { + needsSplit := strings.IndexByte(v, ',') != -1 + if !copied { + if needsSplit { + split := strings.Split(v, ",") + out = make([]string, i, len(values)+len(split)-1) + copy(out, values[:i]) + for _, s := range split { + out = append(out, strings.TrimSpace(s)) + } + copied = true } - case b >= 'A' && b <= 'Z': - if !upper { - h = append(h, b+toLower) + } else { + if needsSplit { + split := strings.Split(v, ",") + for _, s := range split { + out = append(out, strings.TrimSpace(s)) + } } else { - h = append(h, b) + out = append(out, v) } - case b == '-' || b == '_' || b == '.' || (b >= '0' && b <= '9'): - h = append(h, b) } + } + return out +} + +// convert converts a list of string using the passed converter function +func convert(s []string, c converter) []string { + out, _ := convertDidCopy(s, c) + return out +} - if b == ' ' || b == ',' || i == l-1 { - if len(h) > 0 { - // Flush the found header - headers = append(headers, string(h)) - h = h[:0] - upper = true +// convertDidCopy is same as convert but returns true if it copied the slice +func convertDidCopy(s []string, c converter) ([]string, bool) { + out := s + copied := false + for i, v := range s { + if !copied { + v2 := c(v) + if v2 != v { + out = make([]string, len(s)) + copy(out, s[:i]) + out[i] = v2 + copied = true } } else { - upper = b == '-' || b == '_' + out[i] = c(v) } } - return headers + return out, copied } diff --git a/utils_test.go b/utils_test.go index 6475a94..f6a83ec 100644 --- a/utils_test.go +++ b/utils_test.go @@ -1,6 +1,7 @@ package cors import ( + "reflect" "strings" "testing" ) @@ -23,49 +24,46 @@ func TestWildcard(t *testing.T) { } } -func TestConvert(t *testing.T) { - s := convert([]string{"A", "b", "C"}, strings.ToLower) - e := []string{"a", "b", "c"} - if s[0] != e[0] || s[1] != e[1] || s[2] != e[2] { - t.Errorf("%v != %v", s, e) - } -} - -func TestParseHeaderList(t *testing.T) { - h := parseHeaderList("header, second-header, THIRD-HEADER, Numb3r3d-H34d3r") - e := []string{"Header", "Second-Header", "Third-Header", "Numb3r3d-H34d3r"} - if h[0] != e[0] || h[1] != e[1] || h[2] != e[2] { - t.Errorf("%v != %v", h, e) - } -} - -func TestParseHeaderListEmpty(t *testing.T) { - if len(parseHeaderList("")) != 0 { - t.Error("should be empty sclice") +func TestSplitHeaderValues(t *testing.T) { + testCases := []struct { + input []string + expected []string + }{ + { + input: []string{}, + expected: []string{}, + }, + { + input: []string{"foo"}, + expected: []string{"foo"}, + }, + { + input: []string{"foo, bar, baz"}, + expected: []string{"foo", "bar", "baz"}, + }, + { + input: []string{"abc", "def, ghi", "jkl"}, + expected: []string{"abc", "def", "ghi", "jkl"}, + }, + { + input: []string{"foo, bar", "baz, qux", "quux, corge"}, + expected: []string{"foo", "bar", "baz", "qux", "quux", "corge"}, + }, } - if len(parseHeaderList(" , ")) != 0 { - t.Error("should be empty sclice") - } -} - -func BenchmarkParseHeaderList(b *testing.B) { - b.ReportAllocs() - for i := 0; i < b.N; i++ { - parseHeaderList("header, second-header, THIRD-HEADER") - } -} -func BenchmarkParseHeaderListSingle(b *testing.B) { - b.ReportAllocs() - for i := 0; i < b.N; i++ { - parseHeaderList("header") + for _, testCase := range testCases { + output := splitHeaderValues(testCase.input) + if !reflect.DeepEqual(output, testCase.expected) { + t.Errorf("Input: %v, Expected: %v, Got: %v", testCase.input, testCase.expected, output) + } } } -func BenchmarkParseHeaderListNormalized(b *testing.B) { - b.ReportAllocs() - for i := 0; i < b.N; i++ { - parseHeaderList("Header1, Header2, Third-Header") +func TestConvert(t *testing.T) { + s := convert([]string{"A", "b", "C"}, strings.ToLower) + e := []string{"a", "b", "c"} + if s[0] != e[0] || s[1] != e[1] || s[2] != e[2] { + t.Errorf("%v != %v", s, e) } }