Skip to content

Commit

Permalink
Fix TLSServerName for Node API Client
Browse files Browse the repository at this point in the history
This PR fixes the construction of the TLSServerName when connecting to a
node that has TLS enabled and adds tests for all possible permutations.

Fixes #3013
  • Loading branch information
dadgar committed Aug 29, 2017
1 parent 6fb08b8 commit 4d3b75d
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 2 deletions.
23 changes: 21 additions & 2 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ func (c *Config) ClientConfig(region, address string, tlsEnabled bool) *Config {
WaitTime: c.WaitTime,
TLSConfig: c.TLSConfig.Copy(),
}
config.TLSConfig.TLSServerName = fmt.Sprintf("client.%s.nomad", c.Region)
if tlsEnabled && config.TLSConfig != nil {
config.TLSConfig.TLSServerName = fmt.Sprintf("client.%s.nomad", region)
}

return config
}
Expand Down Expand Up @@ -221,6 +223,9 @@ func DefaultConfig() *Config {

// ConfigureTLS applies a set of TLS configurations to the the HTTP client.
func (c *Config) ConfigureTLS() error {
if c.TLSConfig == nil {
return nil
}
if c.HttpClient == nil {
return fmt.Errorf("config HTTP Client must be set")
}
Expand Down Expand Up @@ -300,7 +305,17 @@ func (c *Client) SetRegion(region string) {
// GetNodeClient returns a new Client that will dial the specified node. If the
// QueryOptions is set, its region will be used.
func (c *Client) GetNodeClient(nodeID string, q *QueryOptions) (*Client, error) {
node, _, err := c.Nodes().Info(nodeID, q)
return c.getNodeClientImpl(nodeID, q, c.Nodes().Info)
}

// nodeLookup is used to lookup a node
type nodeLookup func(nodeID string, q *QueryOptions) (*Node, *QueryMeta, error)

// getNodeClientImpl is the implementation of creating a API client for
// contacting a node. It is takes a function to lookup the node such that it can
// be mocked during tests.
func (c *Client) getNodeClientImpl(nodeID string, q *QueryOptions, lookup nodeLookup) (*Client, error) {
node, _, err := lookup(nodeID, q)
if err != nil {
return nil, err
}
Expand All @@ -316,6 +331,10 @@ func (c *Client) GetNodeClient(nodeID string, q *QueryOptions) (*Client, error)
region = q.Region
}

if region == "" {
region = "global"
}

// Get an API client for the node
conf := c.config.ClientConfig(region, node.HTTPAddr, node.TLSEnabled)
return NewClient(conf)
Expand Down
126 changes: 126 additions & 0 deletions api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@ package api

import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"

"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/testutil"
"github.com/stretchr/testify/assert"
)

type configCallback func(c *Config)
Expand Down Expand Up @@ -243,3 +246,126 @@ func TestQueryString(t *testing.T) {
t.Fatalf("bad uri: %q", uri)
}
}

func TestClient_NodeClient(t *testing.T) {
http := "testdomain:4646"
tlsNode := func(string, *QueryOptions) (*Node, *QueryMeta, error) {
return &Node{
ID: structs.GenerateUUID(),
Status: "ready",
HTTPAddr: http,
TLSEnabled: true,
}, nil, nil
}
noTlsNode := func(string, *QueryOptions) (*Node, *QueryMeta, error) {
return &Node{
ID: structs.GenerateUUID(),
Status: "ready",
HTTPAddr: http,
TLSEnabled: false,
}, nil, nil
}

optionNoRegion := &QueryOptions{}
optionRegion := &QueryOptions{
Region: "foo",
}

clientNoRegion, err := NewClient(DefaultConfig())
assert.Nil(t, err)

regionConfig := DefaultConfig()
regionConfig.Region = "bar"
clientRegion, err := NewClient(regionConfig)
assert.Nil(t, err)

expectedTLSAddr := fmt.Sprintf("https://%s", http)
expectedNoTLSAddr := fmt.Sprintf("http://%s", http)

cases := []struct {
Node nodeLookup
QueryOptions *QueryOptions
Client *Client
ExpectedAddr string
ExpectedRegion string
ExpectedTLSServerName string
}{
{
Node: tlsNode,
QueryOptions: optionNoRegion,
Client: clientNoRegion,
ExpectedAddr: expectedTLSAddr,
ExpectedRegion: "global",
ExpectedTLSServerName: "client.global.nomad",
},
{
Node: tlsNode,
QueryOptions: optionRegion,
Client: clientNoRegion,
ExpectedAddr: expectedTLSAddr,
ExpectedRegion: "foo",
ExpectedTLSServerName: "client.foo.nomad",
},
{
Node: tlsNode,
QueryOptions: optionRegion,
Client: clientRegion,
ExpectedAddr: expectedTLSAddr,
ExpectedRegion: "foo",
ExpectedTLSServerName: "client.foo.nomad",
},
{
Node: tlsNode,
QueryOptions: optionNoRegion,
Client: clientRegion,
ExpectedAddr: expectedTLSAddr,
ExpectedRegion: "bar",
ExpectedTLSServerName: "client.bar.nomad",
},
{
Node: noTlsNode,
QueryOptions: optionNoRegion,
Client: clientNoRegion,
ExpectedAddr: expectedNoTLSAddr,
ExpectedRegion: "global",
ExpectedTLSServerName: "",
},
{
Node: noTlsNode,
QueryOptions: optionRegion,
Client: clientNoRegion,
ExpectedAddr: expectedNoTLSAddr,
ExpectedRegion: "foo",
ExpectedTLSServerName: "",
},
{
Node: noTlsNode,
QueryOptions: optionRegion,
Client: clientRegion,
ExpectedAddr: expectedNoTLSAddr,
ExpectedRegion: "foo",
ExpectedTLSServerName: "",
},
{
Node: noTlsNode,
QueryOptions: optionNoRegion,
Client: clientRegion,
ExpectedAddr: expectedNoTLSAddr,
ExpectedRegion: "bar",
ExpectedTLSServerName: "",
},
}

for _, c := range cases {
name := fmt.Sprintf("%s__%s__%s", c.ExpectedAddr, c.ExpectedRegion, c.ExpectedTLSServerName)
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
nodeClient, err := c.Client.getNodeClientImpl("testID", c.QueryOptions, c.Node)
assert.Nil(err)
assert.Equal(c.ExpectedRegion, nodeClient.config.Region)
assert.Equal(c.ExpectedAddr, nodeClient.config.Address)
assert.NotNil(nodeClient.config.TLSConfig)
assert.Equal(c.ExpectedTLSServerName, nodeClient.config.TLSConfig.TLSServerName)
})
}
}

0 comments on commit 4d3b75d

Please sign in to comment.