Skip to content

Commit

Permalink
http_config: Add host (#549)
Browse files Browse the repository at this point in the history
* http_config: Add host

---------

Signed-off-by: Jan-Otto Kröpke <[email protected]>
Signed-off-by: Jan-Otto Kröpke <[email protected]>
Co-authored-by: Ben Kochie <[email protected]>
  • Loading branch information
jkroepke and SuperQ authored Feb 28, 2024
1 parent 699b115 commit 52e512c
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 1 deletion.
38 changes: 38 additions & 0 deletions config/http_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,9 @@ type HTTPClientConfig struct {
// The omitempty flag is not set, because it would be hidden from the
// marshalled configuration when set to false.
EnableHTTP2 bool `yaml:"enable_http2" json:"enable_http2"`
// Host optionally overrides the Host header to send.
// If empty, the host from the URL will be used.
Host string `yaml:"host,omitempty" json:"host,omitempty"`
// Proxy configuration.
ProxyConfig `yaml:",inline"`
}
Expand Down Expand Up @@ -427,6 +430,7 @@ type httpClientOptions struct {
http2Enabled bool
idleConnTimeout time.Duration
userAgent string
host string
}

// HTTPClientOption defines an option that can be applied to the HTTP client.
Expand Down Expand Up @@ -467,6 +471,13 @@ func WithUserAgent(ua string) HTTPClientOption {
}
}

// WithHost allows setting the host header.
func WithHost(host string) HTTPClientOption {
return func(opts *httpClientOptions) {
opts.host = host
}
}

// NewClient returns a http.Client using the specified http.RoundTripper.
func newClient(rt http.RoundTripper) *http.Client {
return &http.Client{Transport: rt}
Expand Down Expand Up @@ -568,6 +579,10 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT
rt = NewUserAgentRoundTripper(opts.userAgent, rt)
}

if opts.host != "" {
rt = NewHostRoundTripper(opts.host, rt)
}

// Return a new configured RoundTripper.
return rt, nil
}
Expand Down Expand Up @@ -1164,11 +1179,21 @@ type userAgentRoundTripper struct {
rt http.RoundTripper
}

type hostRoundTripper struct {
host string
rt http.RoundTripper
}

// NewUserAgentRoundTripper adds the user agent every request header.
func NewUserAgentRoundTripper(userAgent string, rt http.RoundTripper) http.RoundTripper {
return &userAgentRoundTripper{userAgent, rt}
}

// NewHostRoundTripper sets the [http.Request.Host] of every request.
func NewHostRoundTripper(host string, rt http.RoundTripper) http.RoundTripper {
return &hostRoundTripper{host, rt}
}

func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
req = cloneRequest(req)
req.Header.Set("User-Agent", rt.userAgent)
Expand All @@ -1181,6 +1206,19 @@ func (rt *userAgentRoundTripper) CloseIdleConnections() {
}
}

func (rt *hostRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
req = cloneRequest(req)
req.Host = rt.host
req.Header.Set("Host", rt.host)
return rt.rt.RoundTrip(req)
}

func (rt *hostRoundTripper) CloseIdleConnections() {
if ci, ok := rt.rt.(closeIdler); ok {
ci.CloseIdleConnections()
}
}

func (c HTTPClientConfig) String() string {
b, err := yaml.Marshal(c)
if err != nil {
Expand Down
28 changes: 27 additions & 1 deletion config/http_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import (
"testing"
"time"

yaml "gopkg.in/yaml.v2"
"gopkg.in/yaml.v2"
)

const (
Expand Down Expand Up @@ -1671,6 +1671,32 @@ func TestOAuth2UserAgent(t *testing.T) {
}
}

func TestHost(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Host != "localhost.localdomain" {
t.Fatalf("Expected Host header in request to be 'localhost.localdomain', got '%s'", r.Host)
}

w.Header().Add("Content-Type", "application/json")
}))
defer ts.Close()

config := DefaultHTTPClientConfig

rt, err := NewRoundTripperFromConfig(config, "test_host", WithHost("localhost.localdomain"))
if err != nil {
t.Fatal(err)
}

client := http.Client{
Transport: rt,
}
_, err = client.Get(ts.URL)
if err != nil {
t.Fatal(err)
}
}

func TestOAuth2WithFile(t *testing.T) {
var expectedAuth string
ts := newTestOAuthServer(t, &expectedAuth)
Expand Down

0 comments on commit 52e512c

Please sign in to comment.