Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make URL parsing logic match steve #346

Merged
merged 1 commit into from
Feb 12, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 60 additions & 72 deletions urlbuilder/url.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package urlbuilder
import (
"bytes"
"fmt"
"net"
"net/http"
"net/url"
"strings"
Expand All @@ -12,14 +13,15 @@ import (
)

const (
PrefixHeader = "X-API-URL-Prefix"
ForwardedHostHeader = "X-Forwarded-Host"
ForwardedProtoHeader = "X-Forwarded-Proto"
ForwardedPortHeader = "X-Forwarded-Port"
PrefixHeader = "X-API-URL-Prefix"
ForwardedAPIHostHeader = "X-API-Host"
ForwardedHostHeader = "X-Forwarded-Host"
ForwardedProtoHeader = "X-Forwarded-Proto"
ForwardedPortHeader = "X-Forwarded-Port"
)

func New(r *http.Request, version types.APIVersion, schemas *types.Schemas) (types.URLBuilder, error) {
requestURL := parseRequestURL(r)
requestURL := ParseRequestURL(r)
responseURLBase, err := parseResponseURLBase(requestURL, r)
if err != nil {
return nil, err
Expand All @@ -36,6 +38,59 @@ func New(r *http.Request, version types.APIVersion, schemas *types.Schemas) (typ
return builder, nil
}

func ParseRequestURL(r *http.Request) string {
scheme := GetScheme(r)
host := GetHost(r, scheme)
return fmt.Sprintf("%s://%s%s%s", scheme, host, r.Header.Get(PrefixHeader), r.URL.Path)
}

func GetHost(r *http.Request, scheme string) string {
host := r.Header.Get(ForwardedAPIHostHeader)
if host == "" {
host = strings.Split(r.Header.Get(ForwardedHostHeader), ",")[0]
}
if host == "" {
host = r.Host
}

port := r.Header.Get(ForwardedPortHeader)
if port == "" {
return host
}

if port == "80" && scheme == "http" {
return host
}

if port == "443" && scheme == "http" {
return host
}

hostname, _, err := net.SplitHostPort(host)
if err != nil {
return host
}

return strings.Join([]string{hostname, port}, ":")
}

func GetScheme(r *http.Request) string {
scheme := r.Header.Get(ForwardedProtoHeader)
if scheme != "" {
switch scheme {
case "ws":
return "http"
case "wss":
return "https"
default:
return scheme
}
} else if r.TLS != nil {
return "https"
}
return "http"
}

type urlBuilder struct {
schemas *types.Schemas
requestURL string
Expand Down Expand Up @@ -171,73 +226,6 @@ func (u *urlBuilder) getPluralName(schema *types.Schema) string {
return strings.ToLower(schema.PluralName)
}

// Constructs the request URL based off of standard headers in the request, falling back to the HttpServletRequest.getRequestURL()
// if the headers aren't available. Here is the ordered list of how we'll attempt to construct the URL:
// - x-forwarded-proto://x-forwarded-host:x-forwarded-port/HttpServletRequest.getRequestURI()
// - x-forwarded-proto://x-forwarded-host/HttpServletRequest.getRequestURI()
// - x-forwarded-proto://host:x-forwarded-port/HttpServletRequest.getRequestURI()
// - x-forwarded-proto://host/HttpServletRequest.getRequestURI() request.getRequestURL()
//
// Additional notes:
// - If the x-forwarded-host/host header has a port and x-forwarded-port has been passed, x-forwarded-port will be used.
func parseRequestURL(r *http.Request) string {
// Get url from standard headers
requestURL := getURLFromStandardHeaders(r)
if requestURL != "" {
return requestURL
}

// Use incoming url
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
return fmt.Sprintf("%s://%s%s%s", scheme, r.Host, r.Header.Get(PrefixHeader), r.URL.Path)
}

func getURLFromStandardHeaders(r *http.Request) string {
xForwardedProto := getOverrideHeader(r, ForwardedProtoHeader, "")
if xForwardedProto == "" {
return ""
}

host := getOverrideHeader(r, ForwardedHostHeader, "")
if host == "" {
host = r.Host
}

if host == "" {
return ""
}

port := getOverrideHeader(r, ForwardedPortHeader, "")
if port == "443" || port == "80" {
port = "" // Don't include default ports in url
}

if port != "" && strings.Contains(host, ":") {
// Have to strip the port that is in the host. Handle IPv6, which has this format: [::1]:8080
if (strings.HasPrefix(host, "[") && strings.Contains(host, "]:")) || !strings.HasPrefix(host, "[") {
host = host[0:strings.LastIndex(host, ":")]
}
}

if port != "" {
port = ":" + port
}

return fmt.Sprintf("%s://%s%s%s%s", xForwardedProto, host, port, r.Header.Get(PrefixHeader), r.URL.Path)
}

func getOverrideHeader(r *http.Request, header string, defaultValue string) string {
// Need to handle comma separated hosts in X-Forwarded-For
value := r.Header.Get(header)
if value != "" {
return strings.TrimSpace(strings.Split(value, ",")[0])
}
return defaultValue
}

func parseResponseURLBase(requestURL string, r *http.Request) (string, error) {
path := r.URL.Path

Expand Down