Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

api: allow configuring http client #5275

Merged
merged 2 commits into from
May 20, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 108 additions & 60 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,11 @@ type Config struct {
// Namespace to use. If not provided the default namespace is used.
Namespace string

// httpClient is the client to use. Default will be used if not provided.
httpClient *http.Client
// HttpClient is the client to use. Default will be used if not provided.
//
// If set, it expected to be configured for tls already, and TLSConfig is ignored.
// You may use ConfigureTLS() function to aid with initialization.
HttpClient *http.Client

// HttpAuth is the auth info to use for http access.
HttpAuth *HttpBasicAuth
Expand All @@ -132,7 +135,9 @@ type Config struct {
WaitTime time.Duration

// TLSConfig provides the various TLS related configurations for the http
// client
// client.
//
// TLSConfig is ignored if HttpClient is set.
TLSConfig *TLSConfig
}

Expand All @@ -143,12 +148,11 @@ func (c *Config) ClientConfig(region, address string, tlsEnabled bool) *Config {
if tlsEnabled {
scheme = "https"
}
defaultConfig := DefaultConfig()
config := &Config{
Address: fmt.Sprintf("%s://%s", scheme, address),
Region: region,
Namespace: c.Namespace,
httpClient: defaultConfig.httpClient,
HttpClient: c.HttpClient,
SecretID: c.SecretID,
HttpAuth: c.HttpAuth,
WaitTime: c.WaitTime,
Expand Down Expand Up @@ -198,19 +202,23 @@ func (t *TLSConfig) Copy() *TLSConfig {
return nt
}

// DefaultConfig returns a default configuration for the client
func DefaultConfig() *Config {
config := &Config{
Address: "http://127.0.0.1:4646",
httpClient: cleanhttp.DefaultClient(),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously, DefaultConfig initializes httpClient for client to use. This PR changes it so default httpClient creation occurs in NewClient instead and have it be a field of the api client rather than config.

TLSConfig: &TLSConfig{},
}
transport := config.httpClient.Transport.(*http.Transport)
func defaultHttpClient() *http.Client {
httpClient := cleanhttp.DefaultClient()
transport := httpClient.Transport.(*http.Transport)
transport.TLSHandshakeTimeout = 10 * time.Second
transport.TLSClientConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
}

return httpClient
}

// DefaultConfig returns a default configuration for the client
func DefaultConfig() *Config {
config := &Config{
Address: "http://127.0.0.1:4646",
TLSConfig: &TLSConfig{},
}
if addr := os.Getenv("NOMAD_ADDR"); addr != "" {
config.Address = addr
}
Expand Down Expand Up @@ -260,49 +268,84 @@ func DefaultConfig() *Config {
return config
}

// SetTimeout is used to place a timeout for connecting to Nomad. A negative
// duration is ignored, a duration of zero means no timeout, and any other value
// will add a timeout.
func (c *Config) SetTimeout(t time.Duration) error {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed this function, as I believe we only intend to use internally for creating client connections and it's hard for me to imagine someone using it as-is now for other purposes; and if they do, they can set their http client.

If we don't want to remove a public function here, I can reintroduce it but with adding some additional bookkeeping to discern on applying tls config whether it's to default client with timeout or a custom http client.

if c == nil {
return fmt.Errorf("nil config")
} else if c.httpClient == nil {
return fmt.Errorf("nil HTTP client")
} else if c.httpClient.Transport == nil {
return fmt.Errorf("nil HTTP client transport")
// cloneWithTimeout returns a cloned httpClient with set timeout if positive;
// otherwise, returns the same client
func cloneWithTimeout(httpClient *http.Client, t time.Duration) (*http.Client, error) {
if httpClient == nil {
return nil, fmt.Errorf("nil HTTP client")
} else if httpClient.Transport == nil {
return nil, fmt.Errorf("nil HTTP client transport")
}

// Apply a timeout.
if t.Nanoseconds() >= 0 {
transport, ok := c.httpClient.Transport.(*http.Transport)
if !ok {
return fmt.Errorf("unexpected HTTP transport: %T", c.httpClient.Transport)
}

transport.DialContext = (&net.Dialer{
Timeout: t,
KeepAlive: 30 * time.Second,
}).DialContext
if t.Nanoseconds() < 0 {
return httpClient, nil
}

return nil
tr, ok := httpClient.Transport.(*http.Transport)
if !ok {
return nil, fmt.Errorf("unexpected HTTP transport: %T", httpClient.Transport)
}

// copy all public fields, to avoid copying transient state and locks
ntr := &http.Transport{
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, I want to preserve as much of the http client the user provided (including proxies, etc). Doing copying seems to copy transient state and locks; so opted to copy over all fields explicitly. Is there a better way?

Alternatively, we can have users set a HttpClient constructor rather than a simple HttpClient, so we can modify instances easily. I opted not to, because I felt it complicated the API too much.

Proxy: tr.Proxy,
DialContext: tr.DialContext,
Dial: tr.Dial,
DialTLS: tr.DialTLS,
TLSClientConfig: tr.TLSClientConfig,
TLSHandshakeTimeout: tr.TLSHandshakeTimeout,
DisableKeepAlives: tr.DisableKeepAlives,
DisableCompression: tr.DisableCompression,
MaxIdleConns: tr.MaxIdleConns,
MaxIdleConnsPerHost: tr.MaxIdleConnsPerHost,
MaxConnsPerHost: tr.MaxConnsPerHost,
IdleConnTimeout: tr.IdleConnTimeout,
ResponseHeaderTimeout: tr.ResponseHeaderTimeout,
ExpectContinueTimeout: tr.ExpectContinueTimeout,
TLSNextProto: tr.TLSNextProto,
ProxyConnectHeader: tr.ProxyConnectHeader,
MaxResponseHeaderBytes: tr.MaxResponseHeaderBytes,
}

// apply timeout
ntr.DialContext = (&net.Dialer{
Timeout: t,
KeepAlive: 30 * time.Second,
}).DialContext

// clone http client with new transport
nc := *httpClient
nc.Transport = ntr
return &nc, nil
}

// ConfigureTLS applies a set of TLS configurations to the the HTTP client.
//
// Deprecated: This method is called internally. Consider using ConfigureTLS instead.
Copy link
Member

@schmichael schmichael May 20, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using ConfigureTLS instead.

This method is ConfigureTLS. Should we just unexport it since we're already changing the API?

_Update: Of course 3s later I spot the ConfigureTLS func 😅. I think it's probably worth adding "func" to the comment or something though -- or unexport it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll just remove it then. API consumers can just remove invocation without any change in behavior as it's called in NewClient anyway.

func (c *Config) ConfigureTLS() error {
if c.TLSConfig == nil {

// preserve backward behavior where ConfigureTLS pre0.9 always had a client
if c.HttpClient == nil {
c.HttpClient = defaultHttpClient()
}
return ConfigureTLS(c.HttpClient, c.TLSConfig)
}

// ConfigureTLS applies a set of TLS configurations to the the HTTP client.
func ConfigureTLS(httpClient *http.Client, tlsConfig *TLSConfig) error {
if tlsConfig == nil {
return nil
}
if c.httpClient == nil {
if httpClient == nil {
return fmt.Errorf("config HTTP Client must be set")
}

var clientCert tls.Certificate
foundClientCert := false
if c.TLSConfig.ClientCert != "" || c.TLSConfig.ClientKey != "" {
if c.TLSConfig.ClientCert != "" && c.TLSConfig.ClientKey != "" {
if tlsConfig.ClientCert != "" || tlsConfig.ClientKey != "" {
if tlsConfig.ClientCert != "" && tlsConfig.ClientKey != "" {
var err error
clientCert, err = tls.LoadX509KeyPair(c.TLSConfig.ClientCert, c.TLSConfig.ClientKey)
clientCert, err = tls.LoadX509KeyPair(tlsConfig.ClientCert, tlsConfig.ClientKey)
if err != nil {
return err
}
Expand All @@ -312,30 +355,31 @@ func (c *Config) ConfigureTLS() error {
}
}

clientTLSConfig := c.httpClient.Transport.(*http.Transport).TLSClientConfig
clientTLSConfig := httpClient.Transport.(*http.Transport).TLSClientConfig
rootConfig := &rootcerts.Config{
CAFile: c.TLSConfig.CACert,
CAPath: c.TLSConfig.CAPath,
CAFile: tlsConfig.CACert,
CAPath: tlsConfig.CAPath,
}
if err := rootcerts.ConfigureTLS(clientTLSConfig, rootConfig); err != nil {
return err
}

clientTLSConfig.InsecureSkipVerify = c.TLSConfig.Insecure
clientTLSConfig.InsecureSkipVerify = tlsConfig.Insecure

if foundClientCert {
clientTLSConfig.Certificates = []tls.Certificate{clientCert}
}
if c.TLSConfig.TLSServerName != "" {
clientTLSConfig.ServerName = c.TLSConfig.TLSServerName
if tlsConfig.TLSServerName != "" {
clientTLSConfig.ServerName = tlsConfig.TLSServerName
}

return nil
}

// Client provides a client to the Nomad API
type Client struct {
config Config
httpClient *http.Client
config Config
}

// NewClient returns a new client
Expand All @@ -349,17 +393,17 @@ func NewClient(config *Config) (*Client, error) {
return nil, fmt.Errorf("invalid address '%s': %v", config.Address, err)
}

if config.httpClient == nil {
config.httpClient = defConfig.httpClient
}

// Configure the TLS configurations
if err := config.ConfigureTLS(); err != nil {
return nil, err
httpClient := config.HttpClient
if httpClient == nil {
httpClient = defaultHttpClient()
if err := ConfigureTLS(httpClient, config.TLSConfig); err != nil {
return nil, err
}
}

client := &Client{
config: *config,
config: *config,
httpClient: httpClient,
}
return client, nil
}
Expand Down Expand Up @@ -428,8 +472,12 @@ func (c *Client) getNodeClientImpl(nodeID string, timeout time.Duration, q *Quer
// Get an API client for the node
conf := c.config.ClientConfig(region, node.HTTPAddr, node.TLSEnabled)

// Set the timeout
conf.SetTimeout(timeout)
// set timeout - preserve old behavior where errors are ignored and use untimed one
httpClient, err := cloneWithTimeout(c.httpClient, timeout)
if err == nil {
httpClient = c.httpClient
}
conf.HttpClient = httpClient

return NewClient(conf)
}
Expand Down Expand Up @@ -612,7 +660,7 @@ func (c *Client) doRequest(r *request) (time.Duration, *http.Response, error) {
return 0, nil, err
}
start := time.Now()
resp, err := c.config.httpClient.Do(req)
resp, err := c.httpClient.Do(req)
diff := time.Now().Sub(start)

// If the response is compressed, we swap the body's reader.
Expand Down Expand Up @@ -659,14 +707,14 @@ func (c *Client) rawQuery(endpoint string, q *QueryOptions) (io.ReadCloser, erro
// websocket makes a websocket request to the specific endpoint
func (c *Client) websocket(endpoint string, q *QueryOptions) (*websocket.Conn, *http.Response, error) {

transport, ok := c.config.httpClient.Transport.(*http.Transport)
transport, ok := c.httpClient.Transport.(*http.Transport)
if !ok {
return nil, nil, fmt.Errorf("unsupported transport")
}
dialer := websocket.Dialer{
ReadBufferSize: 4096,
WriteBufferSize: 4096,
HandshakeTimeout: c.config.httpClient.Timeout,
HandshakeTimeout: c.httpClient.Timeout,

// values to inherit from http client configuration
NetDial: transport.Dial,
Expand Down
39 changes: 39 additions & 0 deletions api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"testing"
Expand All @@ -13,6 +14,7 @@ import (
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/nomad/api/internal/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type configCallback func(c *Config)
Expand Down Expand Up @@ -443,3 +445,40 @@ func TestClient_NodeClient(t *testing.T) {
})
}
}

func TestCloneHttpClient(t *testing.T) {
client := defaultHttpClient()
originalTransport := client.Transport.(*http.Transport)
originalTransport.Proxy = func(*http.Request) (*url.URL, error) {
return nil, fmt.Errorf("stub function")
}

t.Run("closing with negative timeout", func(t *testing.T) {
clone, err := cloneWithTimeout(client, -1)
require.True(t, originalTransport == client.Transport, "original transport changed")
require.NoError(t, err)
require.Equal(t, client, clone)
require.True(t, client == clone)
})

t.Run("closing with positive timeout", func(t *testing.T) {
clone, err := cloneWithTimeout(client, 1*time.Second)
require.True(t, originalTransport == client.Transport, "original transport changed")
require.NoError(t, err)
require.NotEqual(t, client, clone)
require.True(t, client != clone)
require.True(t, client.Transport != clone.Transport)

// test that proxy function is the same in clone
clonedProxy := clone.Transport.(*http.Transport).Proxy
require.NotNil(t, clonedProxy)
_, err = clonedProxy(nil)
require.Error(t, err)
require.Equal(t, "stub function", err.Error())

// if we reset transport, the strutcs are equal
clone.Transport = originalTransport
require.Equal(t, client, clone)
})

}