diff --git a/urlbuilder/url.go b/urlbuilder/url.go index 2f04f0fec..d72d8f261 100644 --- a/urlbuilder/url.go +++ b/urlbuilder/url.go @@ -3,6 +3,7 @@ package urlbuilder import ( "bytes" "fmt" + "net" "net/http" "net/url" "strings" @@ -12,14 +13,15 @@ import ( ) const ( - PrefixHeader = "X-API-URL-Prefix" - ForwardedHostHeader = "X-Forwarded-Host" - ForwardedProtoHeader = "X-Forwarded-Proto" - ForwardedPortHeader = "X-Forwarded-Port" + PrefixHeader = "X-API-URL-Prefix" + ForwardedAPIHostHeader = "X-API-Host" + ForwardedHostHeader = "X-Forwarded-Host" + ForwardedProtoHeader = "X-Forwarded-Proto" + ForwardedPortHeader = "X-Forwarded-Port" ) func New(r *http.Request, version types.APIVersion, schemas *types.Schemas) (types.URLBuilder, error) { - requestURL := parseRequestURL(r) + requestURL := ParseRequestURL(r) responseURLBase, err := parseResponseURLBase(requestURL, r) if err != nil { return nil, err @@ -36,6 +38,59 @@ func New(r *http.Request, version types.APIVersion, schemas *types.Schemas) (typ return builder, nil } +func ParseRequestURL(r *http.Request) string { + scheme := GetScheme(r) + host := GetHost(r, scheme) + return fmt.Sprintf("%s://%s%s%s", scheme, host, r.Header.Get(PrefixHeader), r.URL.Path) +} + +func GetHost(r *http.Request, scheme string) string { + host := r.Header.Get(ForwardedAPIHostHeader) + if host == "" { + host = strings.Split(r.Header.Get(ForwardedHostHeader), ",")[0] + } + if host == "" { + host = r.Host + } + + port := r.Header.Get(ForwardedPortHeader) + if port == "" { + return host + } + + if port == "80" && scheme == "http" { + return host + } + + if port == "443" && scheme == "http" { + return host + } + + hostname, _, err := net.SplitHostPort(host) + if err != nil { + return host + } + + return strings.Join([]string{hostname, port}, ":") +} + +func GetScheme(r *http.Request) string { + scheme := r.Header.Get(ForwardedProtoHeader) + if scheme != "" { + switch scheme { + case "ws": + return "http" + case "wss": + return "https" + default: + return scheme + } + } else if r.TLS != nil { + return "https" + } + return "http" +} + type urlBuilder struct { schemas *types.Schemas requestURL string @@ -171,73 +226,6 @@ func (u *urlBuilder) getPluralName(schema *types.Schema) string { return strings.ToLower(schema.PluralName) } -// Constructs the request URL based off of standard headers in the request, falling back to the HttpServletRequest.getRequestURL() -// if the headers aren't available. Here is the ordered list of how we'll attempt to construct the URL: -// - x-forwarded-proto://x-forwarded-host:x-forwarded-port/HttpServletRequest.getRequestURI() -// - x-forwarded-proto://x-forwarded-host/HttpServletRequest.getRequestURI() -// - x-forwarded-proto://host:x-forwarded-port/HttpServletRequest.getRequestURI() -// - x-forwarded-proto://host/HttpServletRequest.getRequestURI() request.getRequestURL() -// -// Additional notes: -// - If the x-forwarded-host/host header has a port and x-forwarded-port has been passed, x-forwarded-port will be used. -func parseRequestURL(r *http.Request) string { - // Get url from standard headers - requestURL := getURLFromStandardHeaders(r) - if requestURL != "" { - return requestURL - } - - // Use incoming url - scheme := "http" - if r.TLS != nil { - scheme = "https" - } - return fmt.Sprintf("%s://%s%s%s", scheme, r.Host, r.Header.Get(PrefixHeader), r.URL.Path) -} - -func getURLFromStandardHeaders(r *http.Request) string { - xForwardedProto := getOverrideHeader(r, ForwardedProtoHeader, "") - if xForwardedProto == "" { - return "" - } - - host := getOverrideHeader(r, ForwardedHostHeader, "") - if host == "" { - host = r.Host - } - - if host == "" { - return "" - } - - port := getOverrideHeader(r, ForwardedPortHeader, "") - if port == "443" || port == "80" { - port = "" // Don't include default ports in url - } - - if port != "" && strings.Contains(host, ":") { - // Have to strip the port that is in the host. Handle IPv6, which has this format: [::1]:8080 - if (strings.HasPrefix(host, "[") && strings.Contains(host, "]:")) || !strings.HasPrefix(host, "[") { - host = host[0:strings.LastIndex(host, ":")] - } - } - - if port != "" { - port = ":" + port - } - - return fmt.Sprintf("%s://%s%s%s%s", xForwardedProto, host, port, r.Header.Get(PrefixHeader), r.URL.Path) -} - -func getOverrideHeader(r *http.Request, header string, defaultValue string) string { - // Need to handle comma separated hosts in X-Forwarded-For - value := r.Header.Get(header) - if value != "" { - return strings.TrimSpace(strings.Split(value, ",")[0]) - } - return defaultValue -} - func parseResponseURLBase(requestURL string, r *http.Request) (string, error) { path := r.URL.Path