diff --git a/nomad/endpoints_oss.go b/nomad/endpoints_oss.go index 3d59b57ead0..006b05552c5 100644 --- a/nomad/endpoints_oss.go +++ b/nomad/endpoints_oss.go @@ -2,6 +2,8 @@ package nomad +import "net/rpc" + // EnterpriseEndpoints holds the set of enterprise only endpoints to register type EnterpriseEndpoints struct{} @@ -12,4 +14,4 @@ func NewEnterpriseEndpoints(s *Server) *EnterpriseEndpoints { } // Register is a no-op in oss. -func (e *EnterpriseEndpoints) Register(s *Server) {} +func (e *EnterpriseEndpoints) Register(s *rpc.Server) {} diff --git a/nomad/heartbeat.go b/nomad/heartbeat.go index 89bc8601015..54e885337cb 100644 --- a/nomad/heartbeat.go +++ b/nomad/heartbeat.go @@ -100,7 +100,7 @@ func (s *Server) invalidateHeartbeat(id string) { }, } var resp structs.NodeUpdateResponse - if err := s.endpoints.Node.UpdateStatus(&req, &resp); err != nil { + if err := s.staticEndpoints.Node.UpdateStatus(&req, &resp); err != nil { s.logger.Printf("[ERR] nomad.heartbeat: update status failed: %v", err) } } diff --git a/nomad/node_endpoint.go b/nomad/node_endpoint.go index 7f4265fb972..faa2b973da1 100644 --- a/nomad/node_endpoint.go +++ b/nomad/node_endpoint.go @@ -36,6 +36,9 @@ const ( type Node struct { srv *Server + // ctx provides context regarding the underlying connection + ctx *RPCContext + // updates holds pending client status updates for allocations updates []*structs.Allocation @@ -114,6 +117,13 @@ func (n *Node) Register(args *structs.NodeRegisterRequest, reply *structs.NodeUp } } + // We have a valid node connection, so add the mapping to cache the + // connection and allow the server to send RPCs to the client. + if n.ctx != nil && n.ctx.NodeID == "" { + n.ctx.NodeID = args.Node.ID + n.srv.addNodeConn(n.ctx) + } + // Commit this update via Raft _, index, err := n.srv.raftApply(structs.NodeRegisterRequestType, args) if err != nil { @@ -305,6 +315,13 @@ func (n *Node) UpdateStatus(args *structs.NodeUpdateStatusRequest, reply *struct return fmt.Errorf("node not found") } + // We have a valid node connection, so add the mapping to cache the + // connection and allow the server to send RPCs to the client. + if n.ctx != nil && n.ctx.NodeID == "" { + n.ctx.NodeID = args.NodeID + n.srv.addNodeConn(n.ctx) + } + // XXX: Could use the SecretID here but have to update the heartbeat system // to track SecretIDs. @@ -724,6 +741,13 @@ func (n *Node) GetClientAllocs(args *structs.NodeSpecificRequest, return fmt.Errorf("node secret ID does not match") } + // We have a valid node connection, so add the mapping to cache the + // connection and allow the server to send RPCs to the client. 
+ if n.ctx != nil && n.ctx.NodeID == "" { + n.ctx.NodeID = args.NodeID + n.srv.addNodeConn(n.ctx) + } + var err error allocs, err = state.AllocsByNode(ws, args.NodeID) if err != nil { diff --git a/nomad/node_endpoint_test.go b/nomad/node_endpoint_test.go index 909e2a63715..d3121a77f9d 100644 --- a/nomad/node_endpoint_test.go +++ b/nomad/node_endpoint_test.go @@ -16,15 +16,20 @@ import ( "github.com/hashicorp/nomad/testutil" vapi "github.com/hashicorp/vault/api" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestClientEndpoint_Register(t *testing.T) { t.Parallel() + require := require.New(t) s1 := testServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + // Check that we have no client connections + require.Empty(s1.connectedNodes()) + // Create the register request node := mock.Node() req := &structs.NodeRegisterRequest{ @@ -41,6 +46,11 @@ func TestClientEndpoint_Register(t *testing.T) { t.Fatalf("bad index: %d", resp.Index) } + // Check that we have the client connections + nodes := s1.connectedNodes() + require.Len(nodes, 1) + require.Equal(node.ID, nodes[0]) + // Check for the node in the FSM state := s1.fsm.State() ws := memdb.NewWatchSet() @@ -57,6 +67,15 @@ func TestClientEndpoint_Register(t *testing.T) { if out.ComputedClass == "" { t.Fatal("ComputedClass not set") } + + // Close the connection and check that we remove the client connections + require.Nil(codec.Close()) + testutil.WaitForResult(func() (bool, error) { + nodes := s1.connectedNodes() + return len(nodes) == 0, nil + }, func(err error) { + t.Fatalf("should have no clients") + }) } func TestClientEndpoint_Register_SecretMismatch(t *testing.T) { @@ -260,11 +279,15 @@ func TestClientEndpoint_Deregister_Vault(t *testing.T) { func TestClientEndpoint_UpdateStatus(t *testing.T) { t.Parallel() + require := require.New(t) s1 := testServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + // Check that we have no client connections + require.Empty(s1.connectedNodes()) + // Create the register request node := mock.Node() reg := &structs.NodeRegisterRequest{ @@ -304,6 +327,11 @@ func TestClientEndpoint_UpdateStatus(t *testing.T) { t.Fatalf("bad: %#v", ttl) } + // Check that we have the client connections + nodes := s1.connectedNodes() + require.Len(nodes, 1) + require.Equal(node.ID, nodes[0]) + // Check for the node in the FSM state := s1.fsm.State() ws := memdb.NewWatchSet() @@ -317,6 +345,15 @@ func TestClientEndpoint_UpdateStatus(t *testing.T) { if out.ModifyIndex != resp2.Index { t.Fatalf("index mis-match") } + + // Close the connection and check that we remove the client connections + require.Nil(codec.Close()) + testutil.WaitForResult(func() (bool, error) { + nodes := s1.connectedNodes() + return len(nodes) == 0, nil + }, func(err error) { + t.Fatalf("should have no clients") + }) } func TestClientEndpoint_UpdateStatus_Vault(t *testing.T) { @@ -1230,30 +1267,23 @@ func TestClientEndpoint_GetAllocs_ACL_Basic(t *testing.T) { func TestClientEndpoint_GetClientAllocs(t *testing.T) { t.Parallel() + require := require.New(t) s1 := testServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + // Check that we have no client connections + require.Empty(s1.connectedNodes()) + // Create the register request node := mock.Node() - reg := &structs.NodeRegisterRequest{ - Node: node, - WriteRequest: structs.WriteRequest{Region: "global"}, - } - - // Fetch 
the response - var resp structs.GenericResponse - if err := msgpackrpc.CallWithCodec(codec, "Node.Register", reg, &resp); err != nil { - t.Fatalf("err: %v", err) - } - node.CreateIndex = resp.Index - node.ModifyIndex = resp.Index + state := s1.fsm.State() + require.Nil(state.UpsertNode(98, node)) // Inject fake evaluations alloc := mock.Alloc() alloc.NodeID = node.ID - state := s1.fsm.State() state.UpsertJobSummary(99, mock.JobSummary(alloc.JobID)) err := state.UpsertAllocs(100, []*structs.Allocation{alloc}) if err != nil { @@ -1278,6 +1308,11 @@ func TestClientEndpoint_GetClientAllocs(t *testing.T) { t.Fatalf("bad: %#v", resp2.Allocs) } + // Check that we have the client connections + nodes := s1.connectedNodes() + require.Len(nodes, 1) + require.Equal(node.ID, nodes[0]) + // Lookup node with bad SecretID get.SecretID = "foobarbaz" var resp3 structs.NodeClientAllocsResponse @@ -1298,6 +1333,15 @@ func TestClientEndpoint_GetClientAllocs(t *testing.T) { if len(resp4.Allocs) != 0 { t.Fatalf("unexpected node %#v", resp3.Allocs) } + + // Close the connection and check that we remove the client connections + require.Nil(codec.Close()) + testutil.WaitForResult(func() (bool, error) { + nodes := s1.connectedNodes() + return len(nodes) == 0, nil + }, func(err error) { + t.Fatalf("should have no clients") + }) } func TestClientEndpoint_GetClientAllocs_Blocking(t *testing.T) { @@ -1746,7 +1790,7 @@ func TestClientEndpoint_BatchUpdate(t *testing.T) { // Call to do the batch update bf := NewBatchFuture() - endpoint := s1.endpoints.Node + endpoint := s1.staticEndpoints.Node endpoint.batchUpdate(bf, []*structs.Allocation{clientAlloc}) if err := bf.Wait(); err != nil { t.Fatalf("err: %v", err) @@ -1864,7 +1908,7 @@ func TestClientEndpoint_CreateNodeEvals(t *testing.T) { } // Create some evaluations - ids, index, err := s1.endpoints.Node.createNodeEvals(alloc.NodeID, 1) + ids, index, err := s1.staticEndpoints.Node.createNodeEvals(alloc.NodeID, 1) if err != nil { t.Fatalf("err: %v", err) } diff --git a/nomad/rpc.go b/nomad/rpc.go index 828ee0c94c0..f765e288ca4 100644 --- a/nomad/rpc.go +++ b/nomad/rpc.go @@ -3,6 +3,7 @@ package nomad import ( "context" "crypto/tls" + "crypto/x509" "fmt" "io" "math/rand" @@ -55,6 +56,25 @@ const ( enqueueLimit = 30 * time.Second ) +// RPCContext provides metadata about the RPC connection. +type RPCContext struct { + // Conn exposes the raw connection. + Conn net.Conn + + // Session exposes the multiplexed connection session. + Session *yamux.Session + + // TLS marks whether the RPC is over a TLS based connection + TLS bool + + // VerifiedChains is is the Verified certificates presented by the incoming + // connection. + VerifiedChains [][]*x509.Certificate + + // NodeID marks the NodeID that initiated the connection. + NodeID string +} + // NewClientCodec returns a new rpc.ClientCodec to be used to make RPC calls to // the Nomad Server. 
func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec { @@ -80,14 +100,14 @@ func (s *Server) listen() { continue } - go s.handleConn(conn, false) + go s.handleConn(conn, &RPCContext{Conn: conn}) metrics.IncrCounter([]string{"nomad", "rpc", "accept_conn"}, 1) } } // handleConn is used to determine if this is a Raft or // Nomad type RPC connection and invoke the correct handler -func (s *Server) handleConn(conn net.Conn, isTLS bool) { +func (s *Server) handleConn(conn net.Conn, ctx *RPCContext) { // Read a single byte buf := make([]byte, 1) if _, err := conn.Read(buf); err != nil { @@ -99,7 +119,7 @@ func (s *Server) handleConn(conn net.Conn, isTLS bool) { } // Enforce TLS if EnableRPC is set - if s.config.TLSConfig.EnableRPC && !isTLS && RPCType(buf[0]) != rpcTLS { + if s.config.TLSConfig.EnableRPC && !ctx.TLS && RPCType(buf[0]) != rpcTLS { if !s.config.TLSConfig.RPCUpgradeMode { s.logger.Printf("[WARN] nomad.rpc: Non-TLS connection attempted from %s with RequireTLS set", conn.RemoteAddr().String()) conn.Close() @@ -110,14 +130,21 @@ func (s *Server) handleConn(conn net.Conn, isTLS bool) { // Switch on the byte switch RPCType(buf[0]) { case rpcNomad: - s.handleNomadConn(conn) + // Create an RPC Server and handle the request + server := rpc.NewServer() + s.setupRpcServer(server, ctx) + s.handleNomadConn(conn, server) + + // Remove any potential mapping between a NodeID to this connection and + // close the underlying connection. + s.removeNodeConn(ctx) case rpcRaft: metrics.IncrCounter([]string{"nomad", "rpc", "raft_handoff"}, 1) s.raftLayer.Handoff(conn) case rpcMultiplex: - s.handleMultiplex(conn) + s.handleMultiplex(conn, ctx) case rpcTLS: if s.rpcTLS == nil { @@ -126,7 +153,31 @@ func (s *Server) handleConn(conn net.Conn, isTLS bool) { return } conn = tls.Server(conn, s.rpcTLS) - s.handleConn(conn, true) + + // Force a handshake so we can get information about the TLS connection + // state. + tlsConn, ok := conn.(*tls.Conn) + if !ok { + s.logger.Printf("[ERR] nomad.rpc: expected TLS connection but got %T", conn) + conn.Close() + return + } + + if err := tlsConn.Handshake(); err != nil { + s.logger.Printf("[WARN] nomad.rpc: failed TLS handshake from connection from %v: %v", tlsConn.RemoteAddr(), err) + conn.Close() + return + } + + // Update the connection context with the fact that the connection is + // using TLS + ctx.TLS = true + + // Store the verified chains so they can be inspected later. + state := tlsConn.ConnectionState() + ctx.VerifiedChains = state.VerifiedChains + + s.handleConn(conn, ctx) default: s.logger.Printf("[ERR] nomad.rpc: unrecognized RPC byte: %v", buf[0]) @@ -137,11 +188,25 @@ func (s *Server) handleConn(conn net.Conn, isTLS bool) { // handleMultiplex is used to multiplex a single incoming connection // using the Yamux multiplexer -func (s *Server) handleMultiplex(conn net.Conn) { - defer conn.Close() +func (s *Server) handleMultiplex(conn net.Conn, ctx *RPCContext) { + defer func() { + // Remove any potential mapping between a NodeID to this connection and + // close the underlying connection. 
+ s.removeNodeConn(ctx) + conn.Close() + }() + conf := yamux.DefaultConfig() conf.LogOutput = s.config.LogOutput server, _ := yamux.Server(conn, conf) + + // Update the context to store the yamux session + ctx.Session = server + + // Create the RPC server for this connection + rpcServer := rpc.NewServer() + s.setupRpcServer(rpcServer, ctx) + for { sub, err := server.Accept() if err != nil { @@ -150,12 +215,12 @@ func (s *Server) handleMultiplex(conn net.Conn) { } return } - go s.handleNomadConn(sub) + go s.handleNomadConn(sub, rpcServer) } } // handleNomadConn is used to service a single Nomad RPC connection -func (s *Server) handleNomadConn(conn net.Conn) { +func (s *Server) handleNomadConn(conn net.Conn, server *rpc.Server) { defer conn.Close() rpcCodec := NewServerCodec(conn) for { @@ -165,7 +230,7 @@ func (s *Server) handleNomadConn(conn net.Conn) { default: } - if err := s.rpcServer.ServeRequest(rpcCodec); err != nil { + if err := server.ServeRequest(rpcCodec); err != nil { if err != io.EOF && !strings.Contains(err.Error(), "closed") { s.logger.Printf("[ERR] nomad.rpc: RPC error: %v (%v)", err, conn) metrics.IncrCounter([]string{"nomad", "rpc", "request_error"}, 1) diff --git a/nomad/server.go b/nomad/server.go index 7648c74360f..a826e7601c6 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -29,6 +29,7 @@ import ( "github.com/hashicorp/raft" raftboltdb "github.com/hashicorp/raft-boltdb" "github.com/hashicorp/serf/serf" + "github.com/hashicorp/yamux" ) const ( @@ -88,9 +89,6 @@ type Server struct { // Connection pool to other Nomad servers connPool *ConnPool - // Endpoints holds our RPC endpoints - endpoints endpoints - // The raft instance is used among Nomad nodes within the // region to protect operations that require strong consistency leaderCh <-chan bool @@ -104,13 +102,26 @@ type Server struct { fsm *nomadFSM // rpcListener is used to listen for incoming connections - rpcListener net.Listener - rpcServer *rpc.Server + rpcListener net.Listener + + // rpcServer is the static RPC server that is used by the local agent. + rpcServer *rpc.Server + + // rpcAdvertise is the advertised address for the RPC listener. rpcAdvertise net.Addr // rpcTLS is the TLS config for incoming TLS requests rpcTLS *tls.Config + // staticEndpoints is the set of static endpoints that can be reused across + // all RPC connections + staticEndpoints endpoints + + // nodeConns is the set of multiplexed node connections we have keyed by + // NodeID + nodeConns map[string]*yamux.Session + nodeConnsLock sync.RWMutex + // peers is used to track the known Nomad servers. This is // used for region forwarding and clustering. 
peers map[string][]*serverParts @@ -256,6 +267,7 @@ func NewServer(config *Config, consulCatalog consul.CatalogAPI, logger *log.Logg connPool: NewPool(config.LogOutput, serverRPCCache, serverMaxStreams, tlsWrap), logger: logger, rpcServer: rpc.NewServer(), + nodeConns: make(map[string]*yamux.Session), peers: make(map[string][]*serverParts), localPeers: make(map[raft.ServerAddress]*serverParts), reconcileCh: make(chan serf.Member, 32), @@ -739,37 +751,8 @@ func (s *Server) setupVaultClient() error { // setupRPC is used to setup the RPC listener func (s *Server) setupRPC(tlsWrap tlsutil.RegionWrapper) error { - // Create endpoints - s.endpoints.ACL = &ACL{s} - s.endpoints.Alloc = &Alloc{s} - s.endpoints.Eval = &Eval{s} - s.endpoints.Job = &Job{s} - s.endpoints.Node = &Node{srv: s} - s.endpoints.Deployment = &Deployment{srv: s} - s.endpoints.Operator = &Operator{s} - s.endpoints.Periodic = &Periodic{s} - s.endpoints.Plan = &Plan{s} - s.endpoints.Region = &Region{s} - s.endpoints.Status = &Status{s} - s.endpoints.System = &System{s} - s.endpoints.Search = &Search{s} - s.endpoints.Enterprise = NewEnterpriseEndpoints(s) - - // Register the handlers - s.rpcServer.Register(s.endpoints.ACL) - s.rpcServer.Register(s.endpoints.Alloc) - s.rpcServer.Register(s.endpoints.Eval) - s.rpcServer.Register(s.endpoints.Job) - s.rpcServer.Register(s.endpoints.Node) - s.rpcServer.Register(s.endpoints.Deployment) - s.rpcServer.Register(s.endpoints.Operator) - s.rpcServer.Register(s.endpoints.Periodic) - s.rpcServer.Register(s.endpoints.Plan) - s.rpcServer.Register(s.endpoints.Region) - s.rpcServer.Register(s.endpoints.Status) - s.rpcServer.Register(s.endpoints.System) - s.rpcServer.Register(s.endpoints.Search) - s.endpoints.Enterprise.Register(s) + // Populate the static RPC server + s.setupRpcServer(s.rpcServer, nil) list, err := net.ListenTCP("tcp", s.config.RPCAddr) if err != nil { @@ -799,6 +782,49 @@ func (s *Server) setupRPC(tlsWrap tlsutil.RegionWrapper) error { return nil } +// setupRpcServer is used to populate an RPC server with endpoints +func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) { + // Add the static endpoints to the RPC server. + if s.staticEndpoints.Status == nil { + // Initialize the list just once + s.staticEndpoints.ACL = &ACL{s} + s.staticEndpoints.Alloc = &Alloc{s} + s.staticEndpoints.Eval = &Eval{s} + s.staticEndpoints.Job = &Job{s} + s.staticEndpoints.Node = &Node{srv: s} // Add but don't register + s.staticEndpoints.Deployment = &Deployment{srv: s} + s.staticEndpoints.Operator = &Operator{s} + s.staticEndpoints.Periodic = &Periodic{s} + s.staticEndpoints.Plan = &Plan{s} + s.staticEndpoints.Region = &Region{s} + s.staticEndpoints.Status = &Status{s} + s.staticEndpoints.System = &System{s} + s.staticEndpoints.Search = &Search{s} + s.staticEndpoints.Enterprise = NewEnterpriseEndpoints(s) + } + + // Register the static handlers + server.Register(s.staticEndpoints.ACL) + server.Register(s.staticEndpoints.Alloc) + server.Register(s.staticEndpoints.Eval) + server.Register(s.staticEndpoints.Job) + server.Register(s.staticEndpoints.Deployment) + server.Register(s.staticEndpoints.Operator) + server.Register(s.staticEndpoints.Periodic) + server.Register(s.staticEndpoints.Plan) + server.Register(s.staticEndpoints.Region) + server.Register(s.staticEndpoints.Status) + server.Register(s.staticEndpoints.System) + server.Register(s.staticEndpoints.Search) + s.staticEndpoints.Enterprise.Register(server) + + // Create new dynamic endpoints and add them to the RPC server. 
+ node := &Node{srv: s, ctx: ctx} + + // Register the dynamic endpoints + server.Register(node) +} + // setupRaft is used to setup and initialize Raft func (s *Server) setupRaft() error { // If we have an unclean exit then attempt to close the Raft store. @@ -1155,6 +1181,49 @@ func (s *Server) RPC(method string, args interface{}, reply interface{}) error { return codec.err } +// getNodeConn returns the connection to the given node and whether it exists. +func (s *Server) getNodeConn(nodeID string) (*yamux.Session, bool) { + s.nodeConnsLock.RLock() + defer s.nodeConnsLock.RUnlock() + session, ok := s.nodeConns[nodeID] + return session, ok +} + +// connectedNodes returns the set of nodes we have a connection with. +func (s *Server) connectedNodes() []string { + s.nodeConnsLock.RLock() + defer s.nodeConnsLock.RUnlock() + nodes := make([]string, 0, len(s.nodeConns)) + for nodeID := range s.nodeConns { + nodes = append(nodes, nodeID) + } + return nodes +} + +// addNodeConn adds the mapping between a node and its session. +func (s *Server) addNodeConn(ctx *RPCContext) { + // Hotpath the no-op + if ctx == nil || ctx.NodeID == "" { + return + } + + s.nodeConnsLock.Lock() + defer s.nodeConnsLock.Unlock() + s.nodeConns[ctx.NodeID] = ctx.Session +} + +// removeNodeConn removes the mapping between a node and its session. +func (s *Server) removeNodeConn(ctx *RPCContext) { + // Hotpath the no-op + if ctx == nil || ctx.NodeID == "" { + return + } + + s.nodeConnsLock.Lock() + defer s.nodeConnsLock.Unlock() + delete(s.nodeConns, ctx.NodeID) +} + // Stats is used to return statistics for debugging and insight // for various sub-systems func (s *Server) Stats() map[string]map[string]string { diff --git a/nomad/server_test.go b/nomad/server_test.go index 04175a2900a..cea9a0c3bdf 100644 --- a/nomad/server_test.go +++ b/nomad/server_test.go @@ -144,7 +144,7 @@ func TestServer_RPC(t *testing.T) { } } -func TestServer_RPC_MixedTLS(t *testing.T) { +func TestServer_RPC_TLS(t *testing.T) { t.Parallel() const ( cafile = "../helper/tlsutil/testdata/ca.pem" @@ -154,6 +154,7 @@ func TestServer_RPC_MixedTLS(t *testing.T) { dir := tmpDir(t) defer os.RemoveAll(dir) s1 := testServer(t, func(c *Config) { + c.Region = "regionFoo" c.BootstrapExpect = 3 c.DevMode = false c.DevDisableBootstrap = true @@ -170,53 +171,111 @@ func TestServer_RPC_MixedTLS(t *testing.T) { defer s1.Shutdown() s2 := testServer(t, func(c *Config) { + c.Region = "regionFoo" c.BootstrapExpect = 3 c.DevMode = false c.DevDisableBootstrap = true c.DataDir = path.Join(dir, "node2") + c.TLSConfig = &config.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } }) defer s2.Shutdown() s3 := testServer(t, func(c *Config) { + c.Region = "regionFoo" c.BootstrapExpect = 3 c.DevMode = false c.DevDisableBootstrap = true c.DataDir = path.Join(dir, "node3") + c.TLSConfig = &config.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } }) defer s3.Shutdown() testJoin(t, s1, s2, s3) + testutil.WaitForLeader(t, s1.RPC) - l1, l2, l3, shutdown := make(chan error, 1), make(chan error, 1), make(chan error, 1), make(chan struct{}, 1) + // Part of a server joining is making an RPC request, so just by testing + // that there is a leader we verify that the RPCs are working over TLS. 
+} - wait := func(done chan error, rpc func(string, interface{}, interface{}) error) { - for { - select { - case <-shutdown: - return - default: - } +func TestServer_RPC_MixedTLS(t *testing.T) { + t.Parallel() + const ( + cafile = "../helper/tlsutil/testdata/ca.pem" + foocert = "../helper/tlsutil/testdata/nomad-foo.pem" + fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem" + ) + dir := tmpDir(t) + defer os.RemoveAll(dir) + s1 := testServer(t, func(c *Config) { + c.Region = "regionFoo" + c.BootstrapExpect = 3 + c.DevMode = false + c.DevDisableBootstrap = true + c.DataDir = path.Join(dir, "node1") + c.TLSConfig = &config.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } + }) + defer s1.Shutdown() - args := &structs.GenericRequest{} - var leader string - err := rpc("Status.Leader", args, &leader) - if err != nil || leader != "" { - done <- err - } + s2 := testServer(t, func(c *Config) { + c.Region = "regionFoo" + c.BootstrapExpect = 3 + c.DevMode = false + c.DevDisableBootstrap = true + c.DataDir = path.Join(dir, "node2") + c.TLSConfig = &config.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } + }) + defer s2.Shutdown() + s3 := testServer(t, func(c *Config) { + c.Region = "regionFoo" + c.BootstrapExpect = 3 + c.DevMode = false + c.DevDisableBootstrap = true + c.DataDir = path.Join(dir, "node3") + }) + defer s3.Shutdown() + + testJoin(t, s1, s2, s3) + + // Ensure that we do not form a quorum + start := time.Now() + for { + if time.Now().After(start.Add(2 * time.Second)) { + break } - } - go wait(l1, s1.RPC) - go wait(l2, s2.RPC) - go wait(l3, s3.RPC) - - select { - case <-time.After(5 * time.Second): - case err := <-l1: - t.Fatalf("Server 1 has leader or error: %v", err) - case err := <-l2: - t.Fatalf("Server 2 has leader or error: %v", err) - case err := <-l3: - t.Fatalf("Server 3 has leader or error: %v", err) + args := &structs.GenericRequest{} + var leader string + err := s1.RPC("Status.Leader", args, &leader) + if err == nil || leader != "" { + t.Fatalf("Got leader or no error: %q %v", leader, err) + } } }
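

The comments added in Node.Register, Node.UpdateStatus, and Node.GetClientAllocs say the cached mapping will "allow the server to send RPCs to the client", but this diff only maintains the NodeID -> yamux session map (addNodeConn/getNodeConn/removeNodeConn); nothing consumes it yet, and clients do not yet serve RPCs on streams accepted from their side of the multiplexed connection. As a rough sketch of how a follow-up could use the cache — forwardToNode, its signature, and the assumption of a client-side RPC server are hypothetical; only getNodeConn, NewClientCodec, and the yamux Session API come from this change or its dependencies:

package nomad

import (
	"fmt"
	"net/rpc"
)

// forwardToNode opens a new stream on the node's cached yamux session and
// issues a single RPC over it using the server's msgpack client codec.
// NOTE: hypothetical sketch, not part of this diff.
func (s *Server) forwardToNode(nodeID, method string, args, reply interface{}) error {
	// Look up the multiplexed session cached by addNodeConn.
	session, ok := s.getNodeConn(nodeID)
	if !ok {
		return fmt.Errorf("no open connection to node %q", nodeID)
	}

	// Open a fresh stream so concurrent requests do not share a byte stream.
	stream, err := session.Open()
	if err != nil {
		return fmt.Errorf("failed to open stream to node %q: %v", nodeID, err)
	}

	// Drive a single request/response over the stream; closing the client
	// also closes the codec and, with it, the stream.
	client := rpc.NewClientWithCodec(NewClientCodec(stream))
	defer client.Close()
	return client.Call(method, args, reply)
}

For this to work the client would have to mirror handleMultiplex above: accept streams on its end of the session and serve them with an rpc.Server, which is outside the scope of this diff.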
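
Similarly, handleConn now forces the TLS handshake and stores ctx.VerifiedChains "so they can be inspected later", but nothing in this diff reads it. A plausible later use is gating an endpoint on the certificate the peer presented; the helper below is an assumption for illustration (the function name and the expected-name convention are not Nomad's actual verification logic), using only crypto/x509 APIs and the RPCContext fields introduced here:

package nomad

import "fmt"

// verifyPeerCertificateName checks that the leaf certificate presented on a
// TLS RPC connection is valid for the expected DNS name (for example
// "client.<region>.nomad"). Hypothetical sketch, not part of this diff.
func verifyPeerCertificateName(ctx *RPCContext, expected string) error {
	if ctx == nil || !ctx.TLS {
		return fmt.Errorf("connection is not TLS")
	}
	if len(ctx.VerifiedChains) == 0 || len(ctx.VerifiedChains[0]) == 0 {
		return fmt.Errorf("no verified certificate chain presented")
	}

	// The first certificate in the first verified chain is the peer's leaf.
	leaf := ctx.VerifiedChains[0][0]
	if err := leaf.VerifyHostname(expected); err != nil {
		return fmt.Errorf("certificate does not match %q: %v", expected, err)
	}
	return nil
}

Note that VerifiedChains is only populated when the listener is configured to verify client certificates, so a check like this would also need to account for upgrade-mode connections where no chain is present.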