Skip to content

Commit

Permalink
Merge pull request #3127 from hashicorp/b-tls-api
Browse files Browse the repository at this point in the history
Fix TLSServerName for Node API Client
  • Loading branch information
dadgar authored Aug 29, 2017
2 parents 75221f9 + 40a8efb commit c85e8aa
Show file tree
Hide file tree
Showing 40 changed files with 479 additions and 35 deletions.
54 changes: 39 additions & 15 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,8 @@ type Config struct {
// Region to use. If not provided, the default agent region is used.
Region string

// HttpClient is the client to use. Default will be
// used if not provided.
HttpClient *http.Client
// httpClient is the client to use. Default will be used if not provided.
httpClient *http.Client

// HttpAuth is the auth info to use for http access.
HttpAuth *HttpBasicAuth
Expand All @@ -117,15 +116,18 @@ func (c *Config) ClientConfig(region, address string, tlsEnabled bool) *Config {
if tlsEnabled {
scheme = "https"
}
defaultConfig := DefaultConfig()
config := &Config{
Address: fmt.Sprintf("%s://%s", scheme, address),
Region: region,
HttpClient: c.HttpClient,
httpClient: defaultConfig.httpClient,
HttpAuth: c.HttpAuth,
WaitTime: c.WaitTime,
TLSConfig: c.TLSConfig.Copy(),
}
config.TLSConfig.TLSServerName = fmt.Sprintf("client.%s.nomad", c.Region)
if tlsEnabled && config.TLSConfig != nil {
config.TLSConfig.TLSServerName = fmt.Sprintf("client.%s.nomad", region)
}

return config
}
Expand Down Expand Up @@ -169,10 +171,10 @@ func (t *TLSConfig) Copy() *TLSConfig {
func DefaultConfig() *Config {
config := &Config{
Address: "http://127.0.0.1:4646",
HttpClient: cleanhttp.DefaultClient(),
httpClient: cleanhttp.DefaultClient(),
TLSConfig: &TLSConfig{},
}
transport := config.HttpClient.Transport.(*http.Transport)
transport := config.httpClient.Transport.(*http.Transport)
transport.TLSHandshakeTimeout = 10 * time.Second
transport.TLSClientConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
Expand Down Expand Up @@ -221,7 +223,10 @@ func DefaultConfig() *Config {

// ConfigureTLS applies a set of TLS configurations to the the HTTP client.
func (c *Config) ConfigureTLS() error {
if c.HttpClient == nil {
if c.TLSConfig == nil {
return nil
}
if c.httpClient == nil {
return fmt.Errorf("config HTTP Client must be set")
}

Expand All @@ -240,7 +245,7 @@ func (c *Config) ConfigureTLS() error {
}
}

clientTLSConfig := c.HttpClient.Transport.(*http.Transport).TLSClientConfig
clientTLSConfig := c.httpClient.Transport.(*http.Transport).TLSClientConfig
rootConfig := &rootcerts.Config{
CAFile: c.TLSConfig.CACert,
CAPath: c.TLSConfig.CAPath,
Expand Down Expand Up @@ -277,8 +282,8 @@ func NewClient(config *Config) (*Client, error) {
return nil, fmt.Errorf("invalid address '%s': %v", config.Address, err)
}

if config.HttpClient == nil {
config.HttpClient = defConfig.HttpClient
if config.httpClient == nil {
config.httpClient = defConfig.httpClient
}

// Configure the TLS cofigurations
Expand All @@ -300,7 +305,18 @@ func (c *Client) SetRegion(region string) {
// GetNodeClient returns a new Client that will dial the specified node. If the
// QueryOptions is set, its region will be used.
func (c *Client) GetNodeClient(nodeID string, q *QueryOptions) (*Client, error) {
node, _, err := c.Nodes().Info(nodeID, q)
return c.getNodeClientImpl(nodeID, q, c.Nodes().Info)
}

// nodeLookup is the definition of a function used to lookup a node. This is
// largely used to mock the lookup in tests.
type nodeLookup func(nodeID string, q *QueryOptions) (*Node, *QueryMeta, error)

// getNodeClientImpl is the implementation of creating a API client for
// contacting a node. It takes a function to lookup the node such that it can be
// mocked during tests.
func (c *Client) getNodeClientImpl(nodeID string, q *QueryOptions, lookup nodeLookup) (*Client, error) {
node, _, err := lookup(nodeID, q)
if err != nil {
return nil, err
}
Expand All @@ -311,9 +327,17 @@ func (c *Client) GetNodeClient(nodeID string, q *QueryOptions) (*Client, error)
return nil, fmt.Errorf("http addr of node %q (%s) is not advertised", node.Name, nodeID)
}

region := c.config.Region
if q != nil && q.Region != "" {
var region string
switch {
case q != nil && q.Region != "":
// Prefer the region set in the query parameter
region = q.Region
case c.config.Region != "":
// If the client is configured for a particular region use that
region = c.config.Region
default:
// No region information is given so use the default.
region = "global"
}

// Get an API client for the node
Expand Down Expand Up @@ -471,7 +495,7 @@ func (c *Client) doRequest(r *request) (time.Duration, *http.Response, error) {
return 0, nil, err
}
start := time.Now()
resp, err := c.config.HttpClient.Do(req)
resp, err := c.config.httpClient.Do(req)
diff := time.Now().Sub(start)

// If the response is compressed, we swap the body's reader.
Expand Down
126 changes: 126 additions & 0 deletions api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@ package api

import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"

"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/testutil"
"github.com/stretchr/testify/assert"
)

type configCallback func(c *Config)
Expand Down Expand Up @@ -243,3 +246,126 @@ func TestQueryString(t *testing.T) {
t.Fatalf("bad uri: %q", uri)
}
}

func TestClient_NodeClient(t *testing.T) {
http := "testdomain:4646"
tlsNode := func(string, *QueryOptions) (*Node, *QueryMeta, error) {
return &Node{
ID: structs.GenerateUUID(),
Status: "ready",
HTTPAddr: http,
TLSEnabled: true,
}, nil, nil
}
noTlsNode := func(string, *QueryOptions) (*Node, *QueryMeta, error) {
return &Node{
ID: structs.GenerateUUID(),
Status: "ready",
HTTPAddr: http,
TLSEnabled: false,
}, nil, nil
}

optionNoRegion := &QueryOptions{}
optionRegion := &QueryOptions{
Region: "foo",
}

clientNoRegion, err := NewClient(DefaultConfig())
assert.Nil(t, err)

regionConfig := DefaultConfig()
regionConfig.Region = "bar"
clientRegion, err := NewClient(regionConfig)
assert.Nil(t, err)

expectedTLSAddr := fmt.Sprintf("https://%s", http)
expectedNoTLSAddr := fmt.Sprintf("http://%s", http)

cases := []struct {
Node nodeLookup
QueryOptions *QueryOptions
Client *Client
ExpectedAddr string
ExpectedRegion string
ExpectedTLSServerName string
}{
{
Node: tlsNode,
QueryOptions: optionNoRegion,
Client: clientNoRegion,
ExpectedAddr: expectedTLSAddr,
ExpectedRegion: "global",
ExpectedTLSServerName: "client.global.nomad",
},
{
Node: tlsNode,
QueryOptions: optionRegion,
Client: clientNoRegion,
ExpectedAddr: expectedTLSAddr,
ExpectedRegion: "foo",
ExpectedTLSServerName: "client.foo.nomad",
},
{
Node: tlsNode,
QueryOptions: optionRegion,
Client: clientRegion,
ExpectedAddr: expectedTLSAddr,
ExpectedRegion: "foo",
ExpectedTLSServerName: "client.foo.nomad",
},
{
Node: tlsNode,
QueryOptions: optionNoRegion,
Client: clientRegion,
ExpectedAddr: expectedTLSAddr,
ExpectedRegion: "bar",
ExpectedTLSServerName: "client.bar.nomad",
},
{
Node: noTlsNode,
QueryOptions: optionNoRegion,
Client: clientNoRegion,
ExpectedAddr: expectedNoTLSAddr,
ExpectedRegion: "global",
ExpectedTLSServerName: "",
},
{
Node: noTlsNode,
QueryOptions: optionRegion,
Client: clientNoRegion,
ExpectedAddr: expectedNoTLSAddr,
ExpectedRegion: "foo",
ExpectedTLSServerName: "",
},
{
Node: noTlsNode,
QueryOptions: optionRegion,
Client: clientRegion,
ExpectedAddr: expectedNoTLSAddr,
ExpectedRegion: "foo",
ExpectedTLSServerName: "",
},
{
Node: noTlsNode,
QueryOptions: optionNoRegion,
Client: clientRegion,
ExpectedAddr: expectedNoTLSAddr,
ExpectedRegion: "bar",
ExpectedTLSServerName: "",
},
}

for _, c := range cases {
name := fmt.Sprintf("%s__%s__%s", c.ExpectedAddr, c.ExpectedRegion, c.ExpectedTLSServerName)
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
nodeClient, err := c.Client.getNodeClientImpl("testID", c.QueryOptions, c.Node)
assert.Nil(err)
assert.Equal(c.ExpectedRegion, nodeClient.config.Region)
assert.Equal(c.ExpectedAddr, nodeClient.config.Address)
assert.NotNil(nodeClient.config.TLSConfig)
assert.Equal(c.ExpectedTLSServerName, nodeClient.config.TLSConfig.TLSServerName)
})
}
}
6 changes: 5 additions & 1 deletion command/alloc_status.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,11 @@ func (c *AllocStatusCommand) AutocompleteFlags() complete.Flags {

func (c *AllocStatusCommand) AutocompleteArgs() complete.Predictor {
return complete.PredictFunc(func(a complete.Args) []string {
client, _ := c.Meta.Client()
client, err := c.Meta.Client()
if err != nil {
return nil
}

resp, _, err := client.Search().PrefixSearch(a.Last, contexts.Allocs, nil)
if err != nil {
return []string{}
Expand Down
6 changes: 5 additions & 1 deletion command/deployment_fail.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ func (c *DeploymentFailCommand) AutocompleteFlags() complete.Flags {

func (c *DeploymentFailCommand) AutocompleteArgs() complete.Predictor {
return complete.PredictFunc(func(a complete.Args) []string {
client, _ := c.Meta.Client()
client, err := c.Meta.Client()
if err != nil {
return nil
}

resp, _, err := client.Search().PrefixSearch(a.Last, contexts.Deployments, nil)
if err != nil {
return []string{}
Expand Down
6 changes: 5 additions & 1 deletion command/deployment_pause.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ func (c *DeploymentPauseCommand) AutocompleteFlags() complete.Flags {

func (c *DeploymentPauseCommand) AutocompleteArgs() complete.Predictor {
return complete.PredictFunc(func(a complete.Args) []string {
client, _ := c.Meta.Client()
client, err := c.Meta.Client()
if err != nil {
return nil
}

resp, _, err := client.Search().PrefixSearch(a.Last, contexts.Deployments, nil)
if err != nil {
return []string{}
Expand Down
6 changes: 5 additions & 1 deletion command/deployment_promote.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,11 @@ func (c *DeploymentPromoteCommand) AutocompleteFlags() complete.Flags {

func (c *DeploymentPromoteCommand) AutocompleteArgs() complete.Predictor {
return complete.PredictFunc(func(a complete.Args) []string {
client, _ := c.Meta.Client()
client, err := c.Meta.Client()
if err != nil {
return nil
}

resp, _, err := client.Search().PrefixSearch(a.Last, contexts.Deployments, nil)
if err != nil {
return []string{}
Expand Down
6 changes: 5 additions & 1 deletion command/deployment_resume.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ func (c *DeploymentResumeCommand) AutocompleteFlags() complete.Flags {

func (c *DeploymentResumeCommand) AutocompleteArgs() complete.Predictor {
return complete.PredictFunc(func(a complete.Args) []string {
client, _ := c.Meta.Client()
client, err := c.Meta.Client()
if err != nil {
return nil
}

resp, _, err := client.Search().PrefixSearch(a.Last, contexts.Deployments, nil)
if err != nil {
return []string{}
Expand Down
6 changes: 5 additions & 1 deletion command/deployment_status.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ func (c *DeploymentStatusCommand) AutocompleteFlags() complete.Flags {

func (c *DeploymentStatusCommand) AutocompleteArgs() complete.Predictor {
return complete.PredictFunc(func(a complete.Args) []string {
client, _ := c.Meta.Client()
client, err := c.Meta.Client()
if err != nil {
return nil
}

resp, _, err := client.Search().PrefixSearch(a.Last, contexts.Deployments, nil)
if err != nil {
return []string{}
Expand Down
10 changes: 9 additions & 1 deletion command/eval_status.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,15 @@ func (c *EvalStatusCommand) AutocompleteFlags() complete.Flags {

func (c *EvalStatusCommand) AutocompleteArgs() complete.Predictor {
return complete.PredictFunc(func(a complete.Args) []string {
client, _ := c.Meta.Client()
client, err := c.Meta.Client()
if err != nil {
return nil
}

if err != nil {
return nil
}

resp, _, err := client.Search().PrefixSearch(a.Last, contexts.Evals, nil)
if err != nil {
return []string{}
Expand Down
6 changes: 5 additions & 1 deletion command/fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,11 @@ func (c *FSCommand) AutocompleteFlags() complete.Flags {

func (f *FSCommand) AutocompleteArgs() complete.Predictor {
return complete.PredictFunc(func(a complete.Args) []string {
client, _ := f.Meta.Client()
client, err := f.Meta.Client()
if err != nil {
return nil
}

resp, _, err := client.Search().PrefixSearch(a.Last, contexts.Allocs, nil)
if err != nil {
return []string{}
Expand Down
6 changes: 5 additions & 1 deletion command/inspect.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ func (c *InspectCommand) AutocompleteFlags() complete.Flags {

func (c *InspectCommand) AutocompleteArgs() complete.Predictor {
return complete.PredictFunc(func(a complete.Args) []string {
client, _ := c.Meta.Client()
client, err := c.Meta.Client()
if err != nil {
return nil
}

resp, _, err := client.Search().PrefixSearch(a.Last, contexts.Jobs, nil)
if err != nil {
return []string{}
Expand Down
Loading

0 comments on commit c85e8aa

Please sign in to comment.