-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcors.go
174 lines (136 loc) · 4.63 KB
/
cors.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
package cors
import (
"net/http"
"regexp"
"strconv"
"strings"
)
// CreateMiddleware builds a CORS middleware from the given configuration.
//
// The returned middleware wraps an http.Handler and yields an http.HandlerFunc.
// Because http.HandlerFunc also implements http.Handler, the result can be
// registered with http.HandleFunc or handed to http.ListenAndServe.
//
// For every request the middleware appends the origin, request-method and
// request-headers header names to the Vary response header, then dispatches:
// non-CORS requests go straight to next, preflight requests to
// handlePreflightRequest, and all other CORS requests to handleSimpleRequest.
func CreateMiddleware(c *Config) func(http.Handler) http.HandlerFunc {
	return func(next http.Handler) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			// A response that depends on a request header must name that header in
			// Vary, even when the request did not carry it, so caches keep the
			// variants apart.
			// (https://textslashplain.com/2018/08/02/cors-and-vary/)
			for _, name := range [...]string{OriginHeader, RequestMethodHeader, RequestHeadersHeader} {
				w.Header().Add(VaryHeader, name)
			}

			switch {
			case !isCorsRequest(r):
				next.ServeHTTP(w, r)
			case isPreflightRequest(r):
				handlePreflightRequest(c, w, r, next)
			default:
				handleSimpleRequest(c, w, r, next)
			}
		}
	}
}
// handleSimpleRequest decorates the response to a non-preflight cross-origin
// request and then passes the request on to next.
//
// Origin handling:
//   - c.AllowAllOrigin: the allow-origin header is set to "*".
//   - c.AllowOriginPattern set: the request origin is matched against the
//     pattern and echoed back on a match. A pattern that fails to compile
//     ends the request with 500 and an explanatory body; next is not called.
//   - neither configured: next is called without adding any CORS headers.
//
// When an origin rule is configured, the credentials and expose-headers
// response headers are added according to c.AllowCredentials and
// c.ExposedHeaders before next runs.
func handleSimpleRequest(c *Config, w http.ResponseWriter, r *http.Request, next http.Handler) {
	respHeaders := w.Header()

	switch {
	case c.AllowAllOrigin:
		respHeaders.Set(AllowOriginHeader, "*")
	case c.AllowOriginPattern != "":
		requestOrigin := r.Header.Get(OriginHeader)
		allowed, err := regexp.MatchString(c.AllowOriginPattern, requestOrigin)
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			w.Write([]byte("Origin header validation error: " + err.Error()))
			return
		}
		if allowed {
			respHeaders.Set(AllowOriginHeader, requestOrigin)
		}
	default:
		// No origin policy configured: behave as if the middleware were absent.
		next.ServeHTTP(w, r)
		return
	}

	if c.AllowCredentials {
		respHeaders.Set(AllowCredentialsHeader, "true")
	}
	if exposed := c.ExposedHeaders; len(exposed) > 0 {
		respHeaders.Set(ExposeHeadersHeader, strings.Join(exposed, ","))
	}

	next.ServeHTTP(w, r)
}
// requestHeadersSeparator splits a comma-separated list of header names,
// swallowing the spaces around each comma. It is compiled once at package
// initialisation instead of on every preflight request.
var requestHeadersSeparator = regexp.MustCompile(` *, *`)

// handlePreflightRequest answers a preflight cross-origin request.
//
// The request is processed in this order:
//
//  1. Origin: "*" is allowed when c.AllowAllOrigin is set; otherwise, when
//     c.AllowOriginPattern is set, the request origin is echoed back if it
//     matches the pattern (a pattern that fails to compile ends the request
//     with 500). With neither configured the request is handed to next
//     untouched.
//  2. Method: the method named in the RequestMethodHeader request header is
//     upper-cased and must be present in c.AllowMethods, else the response
//     is 405.
//  3. Headers: unless c.AllowAllHeaders is set, every name listed in the
//     RequestHeadersHeader request header must, after canonicalisation, be
//     present in c.AllowHeaders, else the response is 400.
//  4. The allow-methods, allow-headers and (when c.MaxAge > 0) max-age
//     response headers are written.
//  5. Unless c.ContinuousPreflight is set, the response is terminated with
//     c.PreflightTerminationStatus (200 when that is zero); otherwise next
//     is called.
func handlePreflightRequest(c *Config, w http.ResponseWriter, r *http.Request, next http.Handler) {
	switch {
	case c.AllowAllOrigin:
		w.Header().Set(AllowOriginHeader, "*")
	case c.AllowOriginPattern != "":
		origin := r.Header.Get(OriginHeader)
		match, err := regexp.MatchString(c.AllowOriginPattern, origin)
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			// The status line is already committed; a failed body write
			// cannot be reported to the client, so its error is dropped.
			_, _ = w.Write([]byte("Origin header validation error: " + err.Error()))
			return
		}
		if match {
			w.Header().Set(AllowOriginHeader, origin)
		}
	default:
		// No origin policy configured: behave as if the middleware were absent.
		next.ServeHTTP(w, r)
		return
	}

	method := r.Header.Get(RequestMethodHeader)
	if !contains(strings.ToUpper(method), c.AllowMethods) {
		w.WriteHeader(http.StatusMethodNotAllowed)
		return
	}

	if c.AllowCredentials {
		w.Header().Set(AllowCredentialsHeader, "true")
	}

	allowMethods := c.AllowMethods
	// If the client sent the method in a different case, advertise that exact
	// spelling too. The list is copied first: appending directly to
	// c.AllowMethods could write into spare capacity of the shared
	// configuration slice, mutating it and racing between concurrent requests.
	if !contains(method, c.AllowMethods) {
		extended := make([]string, 0, len(c.AllowMethods)+1)
		extended = append(extended, c.AllowMethods...)
		allowMethods = append(extended, method)
	}

	requestHeaders := r.Header.Get(RequestHeadersHeader)
	if requestHeaders != "" && !c.AllowAllHeaders {
		for _, h := range requestHeadersSeparator.Split(strings.TrimSpace(requestHeaders), -1) {
			h = http.CanonicalHeaderKey(h)
			if !contains(h, c.AllowHeaders) {
				w.WriteHeader(http.StatusBadRequest)
				// See above: nothing useful can be done if this write fails.
				_, _ = w.Write([]byte("Unauthorized header " + h))
				return
			}
		}
	}

	w.Header().Set(AllowMethodsHeader, strings.Join(allowMethods, ","))

	// With AllowAllHeaders the requested list is mirrored back verbatim;
	// otherwise the configured allow-list is advertised.
	var headers string
	if c.AllowAllHeaders {
		headers = requestHeaders
	} else {
		headers = strings.Join(c.AllowHeaders, ",")
	}
	if headers != "" {
		w.Header().Set(AllowHeadersHeader, headers)
	}

	if c.MaxAge > 0 {
		w.Header().Set(MaxAgeHeader, strconv.Itoa(c.MaxAge))
	}

	if !c.ContinuousPreflight {
		status := c.PreflightTerminationStatus
		if status == 0 {
			status = http.StatusOK
		}
		w.WriteHeader(status)
		return
	}

	next.ServeHTTP(w, r)
}
// isCorsRequest reports whether the request is cross-origin. A request is
// treated as same-origin (not CORS) when it carries no origin header, or when
// that header equals the request's Host prefixed with "http://" or "https://".
func isCorsRequest(r *http.Request) bool {
	switch r.Header.Get(OriginHeader) {
	case "", "http://" + r.Host, "https://" + r.Host:
		return false
	}
	return true
}
// isPreflightRequest reports whether the request is a preflight: it must use
// the OPTIONS method and carry a non-empty request-method header.
func isPreflightRequest(r *http.Request) bool {
	if r.Method != http.MethodOptions {
		return false
	}
	return r.Header.Get(RequestMethodHeader) != ""
}