Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix TLSServerName for Node API Client #3127

Merged
merged 8 commits into from
Aug 29, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 39 additions & 15 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,8 @@ type Config struct {
// Region to use. If not provided, the default agent region is used.
Region string

// HttpClient is the client to use. Default will be
// used if not provided.
HttpClient *http.Client
// httpClient is the client to use. Default will be used if not provided.
httpClient *http.Client

// HttpAuth is the auth info to use for http access.
HttpAuth *HttpBasicAuth
Expand All @@ -117,15 +116,18 @@ func (c *Config) ClientConfig(region, address string, tlsEnabled bool) *Config {
if tlsEnabled {
scheme = "https"
}
defaultConfig := DefaultConfig()
config := &Config{
Address: fmt.Sprintf("%s://%s", scheme, address),
Region: region,
HttpClient: c.HttpClient,
httpClient: defaultConfig.httpClient,
HttpAuth: c.HttpAuth,
WaitTime: c.WaitTime,
TLSConfig: c.TLSConfig.Copy(),
}
config.TLSConfig.TLSServerName = fmt.Sprintf("client.%s.nomad", c.Region)
if tlsEnabled && config.TLSConfig != nil {
config.TLSConfig.TLSServerName = fmt.Sprintf("client.%s.nomad", region)
}

return config
}
Expand Down Expand Up @@ -169,10 +171,10 @@ func (t *TLSConfig) Copy() *TLSConfig {
func DefaultConfig() *Config {
config := &Config{
Address: "http://127.0.0.1:4646",
HttpClient: cleanhttp.DefaultClient(),
httpClient: cleanhttp.DefaultClient(),
TLSConfig: &TLSConfig{},
}
transport := config.HttpClient.Transport.(*http.Transport)
transport := config.httpClient.Transport.(*http.Transport)
transport.TLSHandshakeTimeout = 10 * time.Second
transport.TLSClientConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
Expand Down Expand Up @@ -221,7 +223,10 @@ func DefaultConfig() *Config {

// ConfigureTLS applies a set of TLS configurations to the the HTTP client.
func (c *Config) ConfigureTLS() error {
if c.HttpClient == nil {
if c.TLSConfig == nil {
return nil
}
if c.httpClient == nil {
return fmt.Errorf("config HTTP Client must be set")
}

Expand All @@ -240,7 +245,7 @@ func (c *Config) ConfigureTLS() error {
}
}

clientTLSConfig := c.HttpClient.Transport.(*http.Transport).TLSClientConfig
clientTLSConfig := c.httpClient.Transport.(*http.Transport).TLSClientConfig
rootConfig := &rootcerts.Config{
CAFile: c.TLSConfig.CACert,
CAPath: c.TLSConfig.CAPath,
Expand Down Expand Up @@ -277,8 +282,8 @@ func NewClient(config *Config) (*Client, error) {
return nil, fmt.Errorf("invalid address '%s': %v", config.Address, err)
}

if config.HttpClient == nil {
config.HttpClient = defConfig.HttpClient
if config.httpClient == nil {
config.httpClient = defConfig.httpClient
}

// Configure the TLS cofigurations
Expand All @@ -300,7 +305,18 @@ func (c *Client) SetRegion(region string) {
// GetNodeClient returns a new Client that will dial the specified node. If the
// QueryOptions is set, its region will be used.
func (c *Client) GetNodeClient(nodeID string, q *QueryOptions) (*Client, error) {
node, _, err := c.Nodes().Info(nodeID, q)
return c.getNodeClientImpl(nodeID, q, c.Nodes().Info)
}

// nodeLookup is the definition of a function used to lookup a node. This is
// largely used to mock the lookup in tests.
type nodeLookup func(nodeID string, q *QueryOptions) (*Node, *QueryMeta, error)

// getNodeClientImpl is the implementation of creating a API client for
// contacting a node. It takes a function to lookup the node such that it can be
// mocked during tests.
func (c *Client) getNodeClientImpl(nodeID string, q *QueryOptions, lookup nodeLookup) (*Client, error) {
node, _, err := lookup(nodeID, q)
if err != nil {
return nil, err
}
Expand All @@ -311,9 +327,17 @@ func (c *Client) GetNodeClient(nodeID string, q *QueryOptions) (*Client, error)
return nil, fmt.Errorf("http addr of node %q (%s) is not advertised", node.Name, nodeID)
}

region := c.config.Region
if q != nil && q.Region != "" {
var region string
switch {
case q != nil && q.Region != "":
// Prefer the region set in the query parameter
region = q.Region
case c.config.Region != "":
// If the client is configured for a particular region use that
region = c.config.Region
default:
// No region information is given so use the default.
region = "global"
}

// Get an API client for the node
Expand Down Expand Up @@ -471,7 +495,7 @@ func (c *Client) doRequest(r *request) (time.Duration, *http.Response, error) {
return 0, nil, err
}
start := time.Now()
resp, err := c.config.HttpClient.Do(req)
resp, err := c.config.httpClient.Do(req)
diff := time.Now().Sub(start)

// If the response is compressed, we swap the body's reader.
Expand Down
126 changes: 126 additions & 0 deletions api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@ package api

import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"

"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/testutil"
"github.com/stretchr/testify/assert"
)

type configCallback func(c *Config)
Expand Down Expand Up @@ -243,3 +246,126 @@ func TestQueryString(t *testing.T) {
t.Fatalf("bad uri: %q", uri)
}
}

func TestClient_NodeClient(t *testing.T) {
http := "testdomain:4646"
tlsNode := func(string, *QueryOptions) (*Node, *QueryMeta, error) {
return &Node{
ID: structs.GenerateUUID(),
Status: "ready",
HTTPAddr: http,
TLSEnabled: true,
}, nil, nil
}
noTlsNode := func(string, *QueryOptions) (*Node, *QueryMeta, error) {
return &Node{
ID: structs.GenerateUUID(),
Status: "ready",
HTTPAddr: http,
TLSEnabled: false,
}, nil, nil
}

optionNoRegion := &QueryOptions{}
optionRegion := &QueryOptions{
Region: "foo",
}

clientNoRegion, err := NewClient(DefaultConfig())
assert.Nil(t, err)

regionConfig := DefaultConfig()
regionConfig.Region = "bar"
clientRegion, err := NewClient(regionConfig)
assert.Nil(t, err)

expectedTLSAddr := fmt.Sprintf("https://%s", http)
expectedNoTLSAddr := fmt.Sprintf("http://%s", http)

cases := []struct {
Node nodeLookup
QueryOptions *QueryOptions
Client *Client
ExpectedAddr string
ExpectedRegion string
ExpectedTLSServerName string
}{
{
Node: tlsNode,
QueryOptions: optionNoRegion,
Client: clientNoRegion,
ExpectedAddr: expectedTLSAddr,
ExpectedRegion: "global",
ExpectedTLSServerName: "client.global.nomad",
},
{
Node: tlsNode,
QueryOptions: optionRegion,
Client: clientNoRegion,
ExpectedAddr: expectedTLSAddr,
ExpectedRegion: "foo",
ExpectedTLSServerName: "client.foo.nomad",
},
{
Node: tlsNode,
QueryOptions: optionRegion,
Client: clientRegion,
ExpectedAddr: expectedTLSAddr,
ExpectedRegion: "foo",
ExpectedTLSServerName: "client.foo.nomad",
},
{
Node: tlsNode,
QueryOptions: optionNoRegion,
Client: clientRegion,
ExpectedAddr: expectedTLSAddr,
ExpectedRegion: "bar",
ExpectedTLSServerName: "client.bar.nomad",
},
{
Node: noTlsNode,
QueryOptions: optionNoRegion,
Client: clientNoRegion,
ExpectedAddr: expectedNoTLSAddr,
ExpectedRegion: "global",
ExpectedTLSServerName: "",
},
{
Node: noTlsNode,
QueryOptions: optionRegion,
Client: clientNoRegion,
ExpectedAddr: expectedNoTLSAddr,
ExpectedRegion: "foo",
ExpectedTLSServerName: "",
},
{
Node: noTlsNode,
QueryOptions: optionRegion,
Client: clientRegion,
ExpectedAddr: expectedNoTLSAddr,
ExpectedRegion: "foo",
ExpectedTLSServerName: "",
},
{
Node: noTlsNode,
QueryOptions: optionNoRegion,
Client: clientRegion,
ExpectedAddr: expectedNoTLSAddr,
ExpectedRegion: "bar",
ExpectedTLSServerName: "",
},
}

for _, c := range cases {
name := fmt.Sprintf("%s__%s__%s", c.ExpectedAddr, c.ExpectedRegion, c.ExpectedTLSServerName)
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
nodeClient, err := c.Client.getNodeClientImpl("testID", c.QueryOptions, c.Node)
assert.Nil(err)
assert.Equal(c.ExpectedRegion, nodeClient.config.Region)
assert.Equal(c.ExpectedAddr, nodeClient.config.Address)
assert.NotNil(nodeClient.config.TLSConfig)
assert.Equal(c.ExpectedTLSServerName, nodeClient.config.TLSConfig.TLSServerName)
})
}
}
6 changes: 5 additions & 1 deletion command/alloc_status.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,11 @@ func (c *AllocStatusCommand) AutocompleteFlags() complete.Flags {

func (c *AllocStatusCommand) AutocompleteArgs() complete.Predictor {
return complete.PredictFunc(func(a complete.Args) []string {
client, _ := c.Meta.Client()
client, err := c.Meta.Client()
if err != nil {
return nil
}

resp, _, err := client.Search().PrefixSearch(a.Last, contexts.Allocs, nil)
if err != nil {
return []string{}
Expand Down
6 changes: 5 additions & 1 deletion command/deployment_fail.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ func (c *DeploymentFailCommand) AutocompleteFlags() complete.Flags {

func (c *DeploymentFailCommand) AutocompleteArgs() complete.Predictor {
return complete.PredictFunc(func(a complete.Args) []string {
client, _ := c.Meta.Client()
client, err := c.Meta.Client()
if err != nil {
return nil
}

resp, _, err := client.Search().PrefixSearch(a.Last, contexts.Deployments, nil)
if err != nil {
return []string{}
Expand Down
6 changes: 5 additions & 1 deletion command/deployment_pause.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ func (c *DeploymentPauseCommand) AutocompleteFlags() complete.Flags {

func (c *DeploymentPauseCommand) AutocompleteArgs() complete.Predictor {
return complete.PredictFunc(func(a complete.Args) []string {
client, _ := c.Meta.Client()
client, err := c.Meta.Client()
if err != nil {
return nil
}

resp, _, err := client.Search().PrefixSearch(a.Last, contexts.Deployments, nil)
if err != nil {
return []string{}
Expand Down
6 changes: 5 additions & 1 deletion command/deployment_promote.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,11 @@ func (c *DeploymentPromoteCommand) AutocompleteFlags() complete.Flags {

func (c *DeploymentPromoteCommand) AutocompleteArgs() complete.Predictor {
return complete.PredictFunc(func(a complete.Args) []string {
client, _ := c.Meta.Client()
client, err := c.Meta.Client()
if err != nil {
return nil
}

resp, _, err := client.Search().PrefixSearch(a.Last, contexts.Deployments, nil)
if err != nil {
return []string{}
Expand Down
6 changes: 5 additions & 1 deletion command/deployment_resume.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ func (c *DeploymentResumeCommand) AutocompleteFlags() complete.Flags {

func (c *DeploymentResumeCommand) AutocompleteArgs() complete.Predictor {
return complete.PredictFunc(func(a complete.Args) []string {
client, _ := c.Meta.Client()
client, err := c.Meta.Client()
if err != nil {
return nil
}

resp, _, err := client.Search().PrefixSearch(a.Last, contexts.Deployments, nil)
if err != nil {
return []string{}
Expand Down
6 changes: 5 additions & 1 deletion command/deployment_status.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ func (c *DeploymentStatusCommand) AutocompleteFlags() complete.Flags {

func (c *DeploymentStatusCommand) AutocompleteArgs() complete.Predictor {
return complete.PredictFunc(func(a complete.Args) []string {
client, _ := c.Meta.Client()
client, err := c.Meta.Client()
if err != nil {
return nil
}

resp, _, err := client.Search().PrefixSearch(a.Last, contexts.Deployments, nil)
if err != nil {
return []string{}
Expand Down
10 changes: 9 additions & 1 deletion command/eval_status.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,15 @@ func (c *EvalStatusCommand) AutocompleteFlags() complete.Flags {

func (c *EvalStatusCommand) AutocompleteArgs() complete.Predictor {
return complete.PredictFunc(func(a complete.Args) []string {
client, _ := c.Meta.Client()
client, err := c.Meta.Client()
if err != nil {
return nil
}

if err != nil {
return nil
}

resp, _, err := client.Search().PrefixSearch(a.Last, contexts.Evals, nil)
if err != nil {
return []string{}
Expand Down
6 changes: 5 additions & 1 deletion command/fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,11 @@ func (c *FSCommand) AutocompleteFlags() complete.Flags {

func (f *FSCommand) AutocompleteArgs() complete.Predictor {
return complete.PredictFunc(func(a complete.Args) []string {
client, _ := f.Meta.Client()
client, err := f.Meta.Client()
if err != nil {
return nil
}

resp, _, err := client.Search().PrefixSearch(a.Last, contexts.Allocs, nil)
if err != nil {
return []string{}
Expand Down
6 changes: 5 additions & 1 deletion command/inspect.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ func (c *InspectCommand) AutocompleteFlags() complete.Flags {

func (c *InspectCommand) AutocompleteArgs() complete.Predictor {
return complete.PredictFunc(func(a complete.Args) []string {
client, _ := c.Meta.Client()
client, err := c.Meta.Client()
if err != nil {
return nil
}

resp, _, err := client.Search().PrefixSearch(a.Last, contexts.Jobs, nil)
if err != nil {
return []string{}
Expand Down
Loading