diff --git a/api/api.go b/api/api.go index 43f505097aa..9a016a94239 100644 --- a/api/api.go +++ b/api/api.go @@ -121,8 +121,11 @@ type Config struct { // Namespace to use. If not provided the default namespace is used. Namespace string - // httpClient is the client to use. Default will be used if not provided. - httpClient *http.Client + // HttpClient is the client to use. Default will be used if not provided. + // + // If set, it expected to be configured for tls already, and TLSConfig is ignored. + // You may use ConfigureTLS() function to aid with initialization. + HttpClient *http.Client // HttpAuth is the auth info to use for http access. HttpAuth *HttpBasicAuth @@ -132,7 +135,9 @@ type Config struct { WaitTime time.Duration // TLSConfig provides the various TLS related configurations for the http - // client + // client. + // + // TLSConfig is ignored if HttpClient is set. TLSConfig *TLSConfig } @@ -143,12 +148,11 @@ func (c *Config) ClientConfig(region, address string, tlsEnabled bool) *Config { if tlsEnabled { scheme = "https" } - defaultConfig := DefaultConfig() config := &Config{ Address: fmt.Sprintf("%s://%s", scheme, address), Region: region, Namespace: c.Namespace, - httpClient: defaultConfig.httpClient, + HttpClient: c.HttpClient, SecretID: c.SecretID, HttpAuth: c.HttpAuth, WaitTime: c.WaitTime, @@ -198,19 +202,23 @@ func (t *TLSConfig) Copy() *TLSConfig { return nt } -// DefaultConfig returns a default configuration for the client -func DefaultConfig() *Config { - config := &Config{ - Address: "http://127.0.0.1:4646", - httpClient: cleanhttp.DefaultClient(), - TLSConfig: &TLSConfig{}, - } - transport := config.httpClient.Transport.(*http.Transport) +func defaultHttpClient() *http.Client { + httpClient := cleanhttp.DefaultClient() + transport := httpClient.Transport.(*http.Transport) transport.TLSHandshakeTimeout = 10 * time.Second transport.TLSClientConfig = &tls.Config{ MinVersion: tls.VersionTLS12, } + return httpClient +} + +// DefaultConfig returns a default configuration for the client +func DefaultConfig() *Config { + config := &Config{ + Address: "http://127.0.0.1:4646", + TLSConfig: &TLSConfig{}, + } if addr := os.Getenv("NOMAD_ADDR"); addr != "" { config.Address = addr } @@ -260,49 +268,72 @@ func DefaultConfig() *Config { return config } -// SetTimeout is used to place a timeout for connecting to Nomad. A negative -// duration is ignored, a duration of zero means no timeout, and any other value -// will add a timeout. -func (c *Config) SetTimeout(t time.Duration) error { - if c == nil { - return fmt.Errorf("nil config") - } else if c.httpClient == nil { - return fmt.Errorf("nil HTTP client") - } else if c.httpClient.Transport == nil { - return fmt.Errorf("nil HTTP client transport") +// cloneWithTimeout returns a cloned httpClient with set timeout if positive; +// otherwise, returns the same client +func cloneWithTimeout(httpClient *http.Client, t time.Duration) (*http.Client, error) { + if httpClient == nil { + return nil, fmt.Errorf("nil HTTP client") + } else if httpClient.Transport == nil { + return nil, fmt.Errorf("nil HTTP client transport") } - // Apply a timeout. - if t.Nanoseconds() >= 0 { - transport, ok := c.httpClient.Transport.(*http.Transport) - if !ok { - return fmt.Errorf("unexpected HTTP transport: %T", c.httpClient.Transport) - } - - transport.DialContext = (&net.Dialer{ - Timeout: t, - KeepAlive: 30 * time.Second, - }).DialContext + if t.Nanoseconds() < 0 { + return httpClient, nil } - return nil + tr, ok := httpClient.Transport.(*http.Transport) + if !ok { + return nil, fmt.Errorf("unexpected HTTP transport: %T", httpClient.Transport) + } + + // copy all public fields, to avoid copying transient state and locks + ntr := &http.Transport{ + Proxy: tr.Proxy, + DialContext: tr.DialContext, + Dial: tr.Dial, + DialTLS: tr.DialTLS, + TLSClientConfig: tr.TLSClientConfig, + TLSHandshakeTimeout: tr.TLSHandshakeTimeout, + DisableKeepAlives: tr.DisableKeepAlives, + DisableCompression: tr.DisableCompression, + MaxIdleConns: tr.MaxIdleConns, + MaxIdleConnsPerHost: tr.MaxIdleConnsPerHost, + MaxConnsPerHost: tr.MaxConnsPerHost, + IdleConnTimeout: tr.IdleConnTimeout, + ResponseHeaderTimeout: tr.ResponseHeaderTimeout, + ExpectContinueTimeout: tr.ExpectContinueTimeout, + TLSNextProto: tr.TLSNextProto, + ProxyConnectHeader: tr.ProxyConnectHeader, + MaxResponseHeaderBytes: tr.MaxResponseHeaderBytes, + } + + // apply timeout + ntr.DialContext = (&net.Dialer{ + Timeout: t, + KeepAlive: 30 * time.Second, + }).DialContext + + // clone http client with new transport + nc := *httpClient + nc.Transport = ntr + return &nc, nil } // ConfigureTLS applies a set of TLS configurations to the the HTTP client. -func (c *Config) ConfigureTLS() error { - if c.TLSConfig == nil { +func ConfigureTLS(httpClient *http.Client, tlsConfig *TLSConfig) error { + if tlsConfig == nil { return nil } - if c.httpClient == nil { + if httpClient == nil { return fmt.Errorf("config HTTP Client must be set") } var clientCert tls.Certificate foundClientCert := false - if c.TLSConfig.ClientCert != "" || c.TLSConfig.ClientKey != "" { - if c.TLSConfig.ClientCert != "" && c.TLSConfig.ClientKey != "" { + if tlsConfig.ClientCert != "" || tlsConfig.ClientKey != "" { + if tlsConfig.ClientCert != "" && tlsConfig.ClientKey != "" { var err error - clientCert, err = tls.LoadX509KeyPair(c.TLSConfig.ClientCert, c.TLSConfig.ClientKey) + clientCert, err = tls.LoadX509KeyPair(tlsConfig.ClientCert, tlsConfig.ClientKey) if err != nil { return err } @@ -312,22 +343,22 @@ func (c *Config) ConfigureTLS() error { } } - clientTLSConfig := c.httpClient.Transport.(*http.Transport).TLSClientConfig + clientTLSConfig := httpClient.Transport.(*http.Transport).TLSClientConfig rootConfig := &rootcerts.Config{ - CAFile: c.TLSConfig.CACert, - CAPath: c.TLSConfig.CAPath, + CAFile: tlsConfig.CACert, + CAPath: tlsConfig.CAPath, } if err := rootcerts.ConfigureTLS(clientTLSConfig, rootConfig); err != nil { return err } - clientTLSConfig.InsecureSkipVerify = c.TLSConfig.Insecure + clientTLSConfig.InsecureSkipVerify = tlsConfig.Insecure if foundClientCert { clientTLSConfig.Certificates = []tls.Certificate{clientCert} } - if c.TLSConfig.TLSServerName != "" { - clientTLSConfig.ServerName = c.TLSConfig.TLSServerName + if tlsConfig.TLSServerName != "" { + clientTLSConfig.ServerName = tlsConfig.TLSServerName } return nil @@ -335,7 +366,8 @@ func (c *Config) ConfigureTLS() error { // Client provides a client to the Nomad API type Client struct { - config Config + httpClient *http.Client + config Config } // NewClient returns a new client @@ -349,17 +381,17 @@ func NewClient(config *Config) (*Client, error) { return nil, fmt.Errorf("invalid address '%s': %v", config.Address, err) } - if config.httpClient == nil { - config.httpClient = defConfig.httpClient - } - - // Configure the TLS configurations - if err := config.ConfigureTLS(); err != nil { - return nil, err + httpClient := config.HttpClient + if httpClient == nil { + httpClient = defaultHttpClient() + if err := ConfigureTLS(httpClient, config.TLSConfig); err != nil { + return nil, err + } } client := &Client{ - config: *config, + config: *config, + httpClient: httpClient, } return client, nil } @@ -428,8 +460,12 @@ func (c *Client) getNodeClientImpl(nodeID string, timeout time.Duration, q *Quer // Get an API client for the node conf := c.config.ClientConfig(region, node.HTTPAddr, node.TLSEnabled) - // Set the timeout - conf.SetTimeout(timeout) + // set timeout - preserve old behavior where errors are ignored and use untimed one + httpClient, err := cloneWithTimeout(c.httpClient, timeout) + if err == nil { + httpClient = c.httpClient + } + conf.HttpClient = httpClient return NewClient(conf) } @@ -612,7 +648,7 @@ func (c *Client) doRequest(r *request) (time.Duration, *http.Response, error) { return 0, nil, err } start := time.Now() - resp, err := c.config.httpClient.Do(req) + resp, err := c.httpClient.Do(req) diff := time.Now().Sub(start) // If the response is compressed, we swap the body's reader. @@ -659,14 +695,14 @@ func (c *Client) rawQuery(endpoint string, q *QueryOptions) (io.ReadCloser, erro // websocket makes a websocket request to the specific endpoint func (c *Client) websocket(endpoint string, q *QueryOptions) (*websocket.Conn, *http.Response, error) { - transport, ok := c.config.httpClient.Transport.(*http.Transport) + transport, ok := c.httpClient.Transport.(*http.Transport) if !ok { return nil, nil, fmt.Errorf("unsupported transport") } dialer := websocket.Dialer{ ReadBufferSize: 4096, WriteBufferSize: 4096, - HandshakeTimeout: c.config.httpClient.Timeout, + HandshakeTimeout: c.httpClient.Timeout, // values to inherit from http client configuration NetDial: transport.Dial, diff --git a/api/api_test.go b/api/api_test.go index 6313e527153..18fc214d370 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "net/url" "os" "strings" "testing" @@ -13,6 +14,7 @@ import ( "github.com/hashicorp/go-uuid" "github.com/hashicorp/nomad/api/internal/testutil" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type configCallback func(c *Config) @@ -443,3 +445,40 @@ func TestClient_NodeClient(t *testing.T) { }) } } + +func TestCloneHttpClient(t *testing.T) { + client := defaultHttpClient() + originalTransport := client.Transport.(*http.Transport) + originalTransport.Proxy = func(*http.Request) (*url.URL, error) { + return nil, fmt.Errorf("stub function") + } + + t.Run("closing with negative timeout", func(t *testing.T) { + clone, err := cloneWithTimeout(client, -1) + require.True(t, originalTransport == client.Transport, "original transport changed") + require.NoError(t, err) + require.Equal(t, client, clone) + require.True(t, client == clone) + }) + + t.Run("closing with positive timeout", func(t *testing.T) { + clone, err := cloneWithTimeout(client, 1*time.Second) + require.True(t, originalTransport == client.Transport, "original transport changed") + require.NoError(t, err) + require.NotEqual(t, client, clone) + require.True(t, client != clone) + require.True(t, client.Transport != clone.Transport) + + // test that proxy function is the same in clone + clonedProxy := clone.Transport.(*http.Transport).Proxy + require.NotNil(t, clonedProxy) + _, err = clonedProxy(nil) + require.Error(t, err) + require.Equal(t, "stub function", err.Error()) + + // if we reset transport, the strutcs are equal + clone.Transport = originalTransport + require.Equal(t, client, clone) + }) + +}