diff --git a/command/agent/agent.go b/command/agent/agent.go index 7bc2aa76d11..2e643aa2e90 100644 --- a/command/agent/agent.go +++ b/command/agent/agent.go @@ -728,12 +728,14 @@ func (a *Agent) Stats() map[string]map[string]string { // ShouldReload determines if we should reload the configuration and agent // connections. If the TLS Configuration has not changed, we shouldn't reload. -func (a *Agent) ShouldReload(newConfig *Config) (bool, func(*Config) error) { +func (a *Agent) ShouldReload(newConfig *Config) bool { + a.configLock.Lock() + defer a.configLock.Unlock() if a.config.TLSConfig.Equals(newConfig.TLSConfig) { - return false, nil + return false } - return true, a.Reload + return true } // Reload handles configuration changes for the agent. Provides a method that @@ -771,9 +773,9 @@ func (a *Agent) Reload(newConfig *Config) error { a.config.TLSConfig = newConfig.TLSConfig.Copy() if newConfig.TLSConfig.IsEmpty() { - a.logger.Println("[WARN] Downgrading agent's existing TLS configuration to plaintext") + a.logger.Println("[WARN] agent: Downgrading agent's existing TLS configuration to plaintext") } else { - a.logger.Println("[INFO] Upgrading from plaintext configuration to TLS") + a.logger.Println("[INFO] agent: Upgrading from plaintext configuration to TLS") } // Reload the TLS configuration for the client or server, depending on how @@ -781,7 +783,7 @@ func (a *Agent) Reload(newConfig *Config) error { if s := a.Server(); s != nil { err := s.ReloadTLSConnections(a.config.TLSConfig) if err != nil { - a.logger.Printf("[WARN] agent: Issue reloading the server's TLS Configuration, consider a full system restart: %v", err.Error()) + a.logger.Printf("[WARN] agent: error reloading the server's TLS configuration: %v", err.Error()) return err } } @@ -789,7 +791,7 @@ func (a *Agent) Reload(newConfig *Config) error { err := c.ReloadTLSConnections(a.config.TLSConfig) if err != nil { - a.logger.Printf("[ERR] agent: Issue reloading the client's TLS Configuration, consider a full system restart: %v", err.Error()) + a.logger.Printf("[WARN] agent: error reloading the server's TLS configuration: %v", err.Error()) return err } } @@ -797,7 +799,7 @@ func (a *Agent) Reload(newConfig *Config) error { return nil } -// GetConfigCopy creates a replica of the agent's config, excluding locks +// GetConfig creates a locked reference to the agent's config func (a *Agent) GetConfig() *Config { a.configLock.Lock() defer a.configLock.Unlock() diff --git a/command/agent/agent_test.go b/command/agent/agent_test.go index 023060401c9..6042d6902f6 100644 --- a/command/agent/agent_test.go +++ b/command/agent/agent_test.go @@ -888,9 +888,8 @@ func TestServer_ShouldReload_ReturnFalseForNoChanges(t *testing.T) { config: agentConfig, } - shouldReload, reloadFunc := agent.ShouldReload(sameAgentConfig) + shouldReload := agent.ShouldReload(sameAgentConfig) assert.False(shouldReload) - assert.Nil(reloadFunc) } func TestServer_ShouldReload_ReturnTrueForConfigChanges(t *testing.T) { @@ -936,7 +935,6 @@ func TestServer_ShouldReload_ReturnTrueForConfigChanges(t *testing.T) { config: agentConfig, } - shouldReload, reloadFunc := agent.ShouldReload(newConfig) + shouldReload := agent.ShouldReload(newConfig) assert.True(shouldReload) - assert.NotNil(reloadFunc) } diff --git a/command/agent/command.go b/command/agent/command.go index bcaf89bbd65..cc363e5b395 100644 --- a/command/agent/command.go +++ b/command/agent/command.go @@ -598,21 +598,17 @@ WAIT: } } -func (c *Command) reloadHTTPServerOnConfigChange(newConfig *Config) error { +func (c *Command) reloadHTTPServer(newConfig *Config) error { c.agent.logger.Println("[INFO] agent: Reloading HTTP server with new TLS configuration") - err := c.httpServer.Shutdown() - if err != nil { - return err - } + c.httpServerLock.Lock() + defer c.httpServerLock.Unlock() + + c.httpServer.Shutdown() - // Wait some time to ensure a clean shutdown - time.Sleep(5 * time.Second) http, err := NewHTTPServer(c.agent, c.agent.config) if err != nil { return err } - c.httpServerLock.Lock() - defer c.httpServerLock.Unlock() c.httpServer = http return nil @@ -640,18 +636,18 @@ func (c *Command) handleReload() { newConf.LogLevel = c.agent.GetConfig().LogLevel } - shouldReload, reloadFunc := c.agent.ShouldReload(newConf) - if shouldReload && reloadFunc != nil { + shouldReload := c.agent.ShouldReload(newConf) + if shouldReload { // Reloads configuration for an agent running in both client and server mode - err := reloadFunc(newConf) + err := c.agent.Reload(newConf) if err != nil { c.agent.logger.Printf("[ERR] agent: failed to reload the config: %v", err) return } - err = c.reloadHTTPServerOnConfigChange(newConf) + c.reloadHTTPServer(newConf) if err != nil { - c.agent.logger.Printf("[ERR] agent: failed to reload the config: %v", err) + c.agent.logger.Printf("[ERR] http: failed to reload the config: %v", err) return } } diff --git a/command/agent/http.go b/command/agent/http.go index 5e7370296ed..81ce5816636 100644 --- a/command/agent/http.go +++ b/command/agent/http.go @@ -47,11 +47,12 @@ var ( // HTTPServer is used to wrap an Agent and expose it over an HTTP interface type HTTPServer struct { - agent *Agent - mux *http.ServeMux - listener net.Listener - logger *log.Logger - Addr string + agent *Agent + mux *http.ServeMux + listener net.Listener + listenerCh chan struct{} + logger *log.Logger + Addr string } // NewHTTPServer starts new HTTP server over the agent @@ -89,11 +90,12 @@ func NewHTTPServer(agent *Agent, config *Config) (*HTTPServer, error) { // Create the server srv := &HTTPServer{ - agent: agent, - mux: mux, - listener: ln, - logger: agent.logger, - Addr: ln.Addr().String(), + agent: agent, + mux: mux, + listener: ln, + listenerCh: make(chan struct{}), + logger: agent.logger, + Addr: ln.Addr().String(), } srv.registerHandlers(config.EnableDebug) @@ -103,7 +105,10 @@ func NewHTTPServer(agent *Agent, config *Config) (*HTTPServer, error) { return nil, err } - go http.Serve(ln, gzip(mux)) + go func() { + defer close(srv.listenerCh) + http.Serve(ln, gzip(mux)) + }() return srv, nil } @@ -126,12 +131,12 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { } // Shutdown is used to shutdown the HTTP server -func (s *HTTPServer) Shutdown() error { +func (s *HTTPServer) Shutdown() { if s != nil { s.logger.Printf("[DEBUG] http: Shutting down http server") - return s.listener.Close() + s.listener.Close() + <-s.listenerCh // block until http.Serve has returned. } - return nil } // registerHandlers is used to attach our handlers to the mux diff --git a/nomad/raft_rpc.go b/nomad/raft_rpc.go index 4cfcca4009c..cce0c0e5243 100644 --- a/nomad/raft_rpc.go +++ b/nomad/raft_rpc.go @@ -51,7 +51,7 @@ func (l *RaftLayer) Handoff(c net.Conn, ctx context.Context) error { case <-l.closeCh: return fmt.Errorf("Raft RPC layer closed") case <-ctx.Done(): - return fmt.Errorf("[INFO] nomad.rpc: Closing server RPC connection") + return nil } } diff --git a/nomad/rpc.go b/nomad/rpc.go index 8deedd19416..cab276bf8e3 100644 --- a/nomad/rpc.go +++ b/nomad/rpc.go @@ -83,18 +83,24 @@ func (s *Server) listen(ctx context.Context) { if s.shutdown { return } + + select { + case <-ctx.Done(): + return + } + s.logger.Printf("[ERR] nomad.rpc: failed to accept RPC conn: %v", err) continue } - go s.handleConn(conn, false, ctx) + go s.handleConn(ctx, conn, false) metrics.IncrCounter([]string{"nomad", "rpc", "accept_conn"}, 1) } } // handleConn is used to determine if this is a Raft or // Nomad type RPC connection and invoke the correct handler -func (s *Server) handleConn(conn net.Conn, isTLS bool, ctx context.Context) { +func (s *Server) handleConn(ctx context.Context, conn net.Conn, isTLS bool) { // Read a single byte buf := make([]byte, 1) if _, err := conn.Read(buf); err != nil { @@ -117,14 +123,14 @@ func (s *Server) handleConn(conn net.Conn, isTLS bool, ctx context.Context) { // Switch on the byte switch RPCType(buf[0]) { case rpcNomad: - s.handleNomadConn(conn, ctx) + s.handleNomadConn(ctx, conn) case rpcRaft: metrics.IncrCounter([]string{"nomad", "rpc", "raft_handoff"}, 1) s.raftLayer.Handoff(conn, ctx) case rpcMultiplex: - s.handleMultiplex(conn, ctx) + s.handleMultiplex(ctx, conn) case rpcTLS: if s.rpcTLS == nil { @@ -133,7 +139,7 @@ func (s *Server) handleConn(conn net.Conn, isTLS bool, ctx context.Context) { return } conn = tls.Server(conn, s.rpcTLS) - s.handleConn(conn, true, ctx) + s.handleConn(ctx, conn, true) default: s.logger.Printf("[ERR] nomad.rpc: unrecognized RPC byte: %v", buf[0]) @@ -144,7 +150,7 @@ func (s *Server) handleConn(conn net.Conn, isTLS bool, ctx context.Context) { // handleMultiplex is used to multiplex a single incoming connection // using the Yamux multiplexer -func (s *Server) handleMultiplex(conn net.Conn, ctx context.Context) { +func (s *Server) handleMultiplex(ctx context.Context, conn net.Conn) { defer conn.Close() conf := yamux.DefaultConfig() conf.LogOutput = s.config.LogOutput @@ -157,12 +163,12 @@ func (s *Server) handleMultiplex(conn net.Conn, ctx context.Context) { } return } - go s.handleNomadConn(sub, ctx) + go s.handleNomadConn(ctx, sub) } } // handleNomadConn is used to service a single Nomad RPC connection -func (s *Server) handleNomadConn(conn net.Conn, ctx context.Context) { +func (s *Server) handleNomadConn(ctx context.Context, conn net.Conn) { defer conn.Close() rpcCodec := NewServerCodec(conn) for { diff --git a/nomad/server.go b/nomad/server.go index 0f926f01e6e..6a6a7eb2bf2 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -437,7 +437,7 @@ func (s *Server) ReloadTLSConnections(newTLSConfig *config.TLSConfig) error { s.config.LogOutput) s.raftTransport = trans - s.logger.Printf("[INFO] nomad: finished reloading server connections") + s.logger.Printf("[DEBUG] nomad: finished reloading server connections") return nil } diff --git a/nomad/structs/config/tls.go b/nomad/structs/config/tls.go index 2b2f72a8c18..d4e3e04abf7 100644 --- a/nomad/structs/config/tls.go +++ b/nomad/structs/config/tls.go @@ -185,12 +185,14 @@ func (t *TLSConfig) Merge(b *TLSConfig) *TLSConfig { // Equals compares the fields of two TLS configuration objects, returning a // boolean indicating if they are the same. +// NewConfig should never be nil- calling code is responsible for walways +// passing a valid TLSConfig object func (t *TLSConfig) Equals(newConfig *TLSConfig) bool { - if t == nil && newConfig == nil { - return true + if t == nil { + return false } - if t != nil && newConfig == nil { + if t == nil && newConfig != nil { return false }