diff --git a/nomad/client_rpc.go b/nomad/client_rpc.go index af59679b093..e1bbb6317da 100644 --- a/nomad/client_rpc.go +++ b/nomad/client_rpc.go @@ -30,7 +30,16 @@ type nodeConnState struct { func (s *Server) getNodeConn(nodeID string) (*nodeConnState, bool) { s.nodeConnsLock.RLock() defer s.nodeConnsLock.RUnlock() - state, ok := s.nodeConns[nodeID] + conns, ok := s.nodeConns[nodeID] + + // Return the latest conn + var state *nodeConnState + for _, conn := range conns { + if state == nil || state.Established.Before(conn.Established) { + state = conn + } + } + return state, ok } @@ -39,8 +48,12 @@ func (s *Server) connectedNodes() map[string]time.Time { s.nodeConnsLock.RLock() defer s.nodeConnsLock.RUnlock() nodes := make(map[string]time.Time, len(s.nodeConns)) - for nodeID, state := range s.nodeConns { - nodes[nodeID] = state.Established + for nodeID, conns := range s.nodeConns { + for _, conn := range conns { + if nodes[nodeID].Before(conn.Established) { + nodes[nodeID] = conn.Established + } + } } return nodes } @@ -54,11 +67,26 @@ func (s *Server) addNodeConn(ctx *RPCContext) { s.nodeConnsLock.Lock() defer s.nodeConnsLock.Unlock() - s.nodeConns[ctx.NodeID] = &nodeConnState{ + + // Capture the tracked connections so far + currentConns := s.nodeConns[ctx.NodeID] + + // Check if we already have the connection. If we do, just update the + // establish time. + for _, c := range currentConns { + if c.Ctx.Conn.LocalAddr().String() == ctx.Conn.LocalAddr().String() && + c.Ctx.Conn.RemoteAddr().String() == ctx.Conn.RemoteAddr().String() { + c.Established = time.Now() + return + } + } + + // Add the new conn + s.nodeConns[ctx.NodeID] = append(s.nodeConns[ctx.NodeID], &nodeConnState{ Session: ctx.Session, Established: time.Now(), Ctx: ctx, - } + }) } // removeNodeConn removes the mapping between a node and its session. @@ -70,7 +98,7 @@ func (s *Server) removeNodeConn(ctx *RPCContext) { s.nodeConnsLock.Lock() defer s.nodeConnsLock.Unlock() - state, ok := s.nodeConns[ctx.NodeID] + conns, ok := s.nodeConns[ctx.NodeID] if !ok { return } @@ -80,9 +108,12 @@ func (s *Server) removeNodeConn(ctx *RPCContext) { // dial various addresses that all route to the same server. The most common // case for this is the original address the client uses to connect to the // server differs from the advertised address sent by the heartbeat. - if state.Ctx.Conn.LocalAddr().String() == ctx.Conn.LocalAddr().String() && - state.Ctx.Conn.RemoteAddr().String() == ctx.Conn.RemoteAddr().String() { - delete(s.nodeConns, ctx.NodeID) + for i, conn := range conns { + if conn.Ctx.Conn.LocalAddr().String() == ctx.Conn.LocalAddr().String() && + conn.Ctx.Conn.RemoteAddr().String() == ctx.Conn.RemoteAddr().String() { + s.nodeConns[ctx.NodeID] = append(s.nodeConns[ctx.NodeID][:i], s.nodeConns[ctx.NodeID][i+1:]...) + return + } } } diff --git a/nomad/client_rpc_test.go b/nomad/client_rpc_test.go index c64eecec029..d7edad6bc51 100644 --- a/nomad/client_rpc_test.go +++ b/nomad/client_rpc_test.go @@ -57,6 +57,7 @@ func TestServer_removeNodeConn_differentAddrs(t *testing.T) { s1.addNodeConn(ctx1) s1.addNodeConn(ctx2) require.Len(s1.connectedNodes(), 1) + require.Len(s1.nodeConns[nodeID], 2) // Check that the value is the second conn. state, ok := s1.getNodeConn(nodeID) @@ -66,6 +67,7 @@ func TestServer_removeNodeConn_differentAddrs(t *testing.T) { // Delete the first s1.removeNodeConn(ctx1) require.Len(s1.connectedNodes(), 1) + require.Len(s1.nodeConns[nodeID], 1) // Check that the value is the second conn. state, ok = s1.getNodeConn(nodeID) diff --git a/nomad/server.go b/nomad/server.go index 5a214d71ea4..235e2988a3d 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -143,7 +143,7 @@ type Server struct { // nodeConns is the set of multiplexed node connections we have keyed by // NodeID - nodeConns map[string]*nodeConnState + nodeConns map[string][]*nodeConnState nodeConnsLock sync.RWMutex // peers is used to track the known Nomad servers. This is @@ -294,7 +294,7 @@ func NewServer(config *Config, consulCatalog consul.CatalogAPI, logger *log.Logg tlsWrap: tlsWrap, rpcServer: rpc.NewServer(), streamingRpcs: structs.NewStreamingRpcRegistry(), - nodeConns: make(map[string]*nodeConnState), + nodeConns: make(map[string][]*nodeConnState), peers: make(map[string][]*serverParts), localPeers: make(map[raft.ServerAddress]*serverParts), reconcileCh: make(chan serf.Member, 32),