From 258ed556ba3203ac26f9c32c8ae762acb3a588e6 Mon Sep 17 00:00:00 2001 From: Alex Dadgar Date: Wed, 1 Feb 2017 14:20:14 -0800 Subject: [PATCH] Vault Client on Server handles SIGHUP This PR allows the Vault client on the server to handle a SIGHUP. This allows updating the Vault token and any other configuration without downtime. --- command/agent/agent.go | 86 +++++++++++++++++++++------------------- command/agent/command.go | 12 ++++++ nomad/server.go | 18 +++++++++ nomad/server_test.go | 25 ++++++++++++ nomad/vault.go | 24 ++++++++--- nomad/vault_test.go | 34 ++++++++++++++++ nomad/vault_testing.go | 1 + 7 files changed, 155 insertions(+), 45 deletions(-) diff --git a/command/agent/agent.go b/command/agent/agent.go index 05aaa4734d9..ca4df6672e9 100644 --- a/command/agent/agent.go +++ b/command/agent/agent.go @@ -88,56 +88,56 @@ func NewAgent(config *Config, logOutput io.Writer) (*Agent, error) { return a, nil } -// serverConfig is used to generate a new server configuration struct -// for initializing a nomad server. -func (a *Agent) serverConfig() (*nomad.Config, error) { - conf := a.config.NomadConfig +// convertServerConfig takes an agent config and log output and returns a Nomad +// Config. 
+func convertServerConfig(agentConfig *Config, logOutput io.Writer) (*nomad.Config, error) { + conf := agentConfig.NomadConfig if conf == nil { conf = nomad.DefaultConfig() } - conf.LogOutput = a.logOutput - conf.DevMode = a.config.DevMode - conf.Build = fmt.Sprintf("%s%s", a.config.Version, a.config.VersionPrerelease) - if a.config.Region != "" { - conf.Region = a.config.Region + conf.LogOutput = logOutput + conf.DevMode = agentConfig.DevMode + conf.Build = fmt.Sprintf("%s%s", agentConfig.Version, agentConfig.VersionPrerelease) + if agentConfig.Region != "" { + conf.Region = agentConfig.Region } - if a.config.Datacenter != "" { - conf.Datacenter = a.config.Datacenter + if agentConfig.Datacenter != "" { + conf.Datacenter = agentConfig.Datacenter } - if a.config.NodeName != "" { - conf.NodeName = a.config.NodeName + if agentConfig.NodeName != "" { + conf.NodeName = agentConfig.NodeName } - if a.config.Server.BootstrapExpect > 0 { - if a.config.Server.BootstrapExpect == 1 { + if agentConfig.Server.BootstrapExpect > 0 { + if agentConfig.Server.BootstrapExpect == 1 { conf.Bootstrap = true } else { - atomic.StoreInt32(&conf.BootstrapExpect, int32(a.config.Server.BootstrapExpect)) + atomic.StoreInt32(&conf.BootstrapExpect, int32(agentConfig.Server.BootstrapExpect)) } } - if a.config.DataDir != "" { - conf.DataDir = filepath.Join(a.config.DataDir, "server") + if agentConfig.DataDir != "" { + conf.DataDir = filepath.Join(agentConfig.DataDir, "server") } - if a.config.Server.DataDir != "" { - conf.DataDir = a.config.Server.DataDir + if agentConfig.Server.DataDir != "" { + conf.DataDir = agentConfig.Server.DataDir } - if a.config.Server.ProtocolVersion != 0 { - conf.ProtocolVersion = uint8(a.config.Server.ProtocolVersion) + if agentConfig.Server.ProtocolVersion != 0 { + conf.ProtocolVersion = uint8(agentConfig.Server.ProtocolVersion) } - if a.config.Server.NumSchedulers != 0 { - conf.NumSchedulers = a.config.Server.NumSchedulers + if agentConfig.Server.NumSchedulers != 0 { + 
conf.NumSchedulers = agentConfig.Server.NumSchedulers } - if len(a.config.Server.EnabledSchedulers) != 0 { - conf.EnabledSchedulers = a.config.Server.EnabledSchedulers + if len(agentConfig.Server.EnabledSchedulers) != 0 { + conf.EnabledSchedulers = agentConfig.Server.EnabledSchedulers } // Set up the bind addresses - rpcAddr, err := net.ResolveTCPAddr("tcp", a.config.normalizedAddrs.RPC) + rpcAddr, err := net.ResolveTCPAddr("tcp", agentConfig.normalizedAddrs.RPC) if err != nil { - return nil, fmt.Errorf("Failed to parse RPC address %q: %v", a.config.normalizedAddrs.RPC, err) + return nil, fmt.Errorf("Failed to parse RPC address %q: %v", agentConfig.normalizedAddrs.RPC, err) } - serfAddr, err := net.ResolveTCPAddr("tcp", a.config.normalizedAddrs.Serf) + serfAddr, err := net.ResolveTCPAddr("tcp", agentConfig.normalizedAddrs.Serf) if err != nil { - return nil, fmt.Errorf("Failed to parse Serf address %q: %v", a.config.normalizedAddrs.Serf, err) + return nil, fmt.Errorf("Failed to parse Serf address %q: %v", agentConfig.normalizedAddrs.Serf, err) } conf.RPCAddr.Port = rpcAddr.Port conf.RPCAddr.IP = rpcAddr.IP @@ -145,20 +145,20 @@ func (a *Agent) serverConfig() (*nomad.Config, error) { conf.SerfConfig.MemberlistConfig.BindAddr = serfAddr.IP.String() // Set up the advertise addresses - rpcAddr, err = net.ResolveTCPAddr("tcp", a.config.AdvertiseAddrs.RPC) + rpcAddr, err = net.ResolveTCPAddr("tcp", agentConfig.AdvertiseAddrs.RPC) if err != nil { - return nil, fmt.Errorf("Failed to parse RPC advertise address %q: %v", a.config.AdvertiseAddrs.RPC, err) + return nil, fmt.Errorf("Failed to parse RPC advertise address %q: %v", agentConfig.AdvertiseAddrs.RPC, err) } - serfAddr, err = net.ResolveTCPAddr("tcp", a.config.AdvertiseAddrs.Serf) + serfAddr, err = net.ResolveTCPAddr("tcp", agentConfig.AdvertiseAddrs.Serf) if err != nil { - return nil, fmt.Errorf("Failed to parse Serf advertise address %q: %v", a.config.AdvertiseAddrs.Serf, err) + return nil, fmt.Errorf("Failed to parse 
Serf advertise address %q: %v", agentConfig.AdvertiseAddrs.Serf, err) } conf.RPCAdvertise = rpcAddr conf.SerfConfig.MemberlistConfig.AdvertiseAddr = serfAddr.IP.String() conf.SerfConfig.MemberlistConfig.AdvertisePort = serfAddr.Port // Set up gc threshold and heartbeat grace period - if gcThreshold := a.config.Server.NodeGCThreshold; gcThreshold != "" { + if gcThreshold := agentConfig.Server.NodeGCThreshold; gcThreshold != "" { dur, err := time.ParseDuration(gcThreshold) if err != nil { return nil, err @@ -166,7 +166,7 @@ func (a *Agent) serverConfig() (*nomad.Config, error) { conf.NodeGCThreshold = dur } - if heartbeatGrace := a.config.Server.HeartbeatGrace; heartbeatGrace != "" { + if heartbeatGrace := agentConfig.Server.HeartbeatGrace; heartbeatGrace != "" { dur, err := time.ParseDuration(heartbeatGrace) if err != nil { return nil, err @@ -174,20 +174,26 @@ func (a *Agent) serverConfig() (*nomad.Config, error) { conf.HeartbeatGrace = dur } - if *a.config.Consul.AutoAdvertise && a.config.Consul.ServerServiceName == "" { + if *agentConfig.Consul.AutoAdvertise && agentConfig.Consul.ServerServiceName == "" { return nil, fmt.Errorf("server_service_name must be set when auto_advertise is enabled") } // Add the Consul and Vault configs - conf.ConsulConfig = a.config.Consul - conf.VaultConfig = a.config.Vault + conf.ConsulConfig = agentConfig.Consul + conf.VaultConfig = agentConfig.Vault // Set the TLS config - conf.TLSConfig = a.config.TLSConfig + conf.TLSConfig = agentConfig.TLSConfig return conf, nil } +// serverConfig is used to generate a new server configuration struct +// for initializing a nomad server. +func (a *Agent) serverConfig() (*nomad.Config, error) { + return convertServerConfig(a.config, a.logOutput) +} + // clientConfig is used to generate a new client configuration struct // for initializing a Nomad client. 
func (a *Agent) clientConfig() (*clientconfig.Config, error) { diff --git a/command/agent/command.go b/command/agent/command.go index 0e7b368fa73..85ed1938005 100644 --- a/command/agent/command.go +++ b/command/agent/command.go @@ -601,6 +601,18 @@ func (c *Command) handleReload(config *Config) *Config { // Keep the current log level newConf.LogLevel = config.LogLevel } + + if s := c.agent.Server(); s != nil { + sconf, err := convertServerConfig(newConf, c.logOutput) + if err != nil { + c.agent.logger.Printf("[ERR] agent: failed to convert server config: %v", err) + } else { + if err := s.Reload(sconf); err != nil { + c.agent.logger.Printf("[ERR] agent: reloading server config failed: %v", err) + } + } + } + return newConf } diff --git a/nomad/server.go b/nomad/server.go index ef7995d581b..d2c2582f95c 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -411,6 +411,24 @@ func (s *Server) Leave() error { return nil } +// Reload handles a config reload. Not all config fields can handle a reload. +func (s *Server) Reload(config *Config) error { + if config == nil { + return fmt.Errorf("Reload given a nil config") + } + + var mErr multierror.Error + + // Handle the Vault reload. Vault should never be nil but just guard. + if s.vault != nil { + if err := s.vault.SetConfig(config.VaultConfig); err != nil { + multierror.Append(&mErr, err) + } + } + + return mErr.ErrorOrNil() +} + // setupBootstrapHandler() creates the closure necessary to support a Consul // fallback handler. 
func (s *Server) setupBootstrapHandler() error { diff --git a/nomad/server_test.go b/nomad/server_test.go index c66a224984b..94d245cfd1c 100644 --- a/nomad/server_test.go +++ b/nomad/server_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/hashicorp/nomad/command/agent/consul" + "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/testutil" ) @@ -136,3 +137,27 @@ func TestServer_Regions(t *testing.T) { t.Fatalf("err: %v", err) }) } + +func TestServer_Reload_Vault(t *testing.T) { + s1 := testServer(t, func(c *Config) { + c.Region = "region1" + }) + defer s1.Shutdown() + + if s1.vault.Running() { + t.Fatalf("Vault client should not be running") + } + + tr := true + config := s1.config + config.VaultConfig.Enabled = &tr + config.VaultConfig.Token = structs.GenerateUUID() + + if err := s1.Reload(config); err != nil { + t.Fatalf("Reload failed: %v", err) + } + + if !s1.vault.Running() { + t.Fatalf("Vault client should be running") + } +} diff --git a/nomad/vault.go b/nomad/vault.go index a78c8221039..898825d6ea6 100644 --- a/nomad/vault.go +++ b/nomad/vault.go @@ -132,6 +132,9 @@ type VaultClient interface { // Stop is used to stop token renewal Stop() + + // Running returns whether the Vault client is running + Running() bool } // PurgeVaultAccessor is called to remove VaultAccessors from the system. If @@ -254,6 +257,12 @@ func (v *vaultClient) Stop() { } } +func (v *vaultClient) Running() bool { + v.l.Lock() + defer v.l.Unlock() + return v.running +} + // SetActive activates or de-activates the Vault client. When active, token // creation/lookup/revocation operation are allowed. 
All queued revocations are // cancelled if set un-active as it is assumed another instances is taking over @@ -298,10 +307,8 @@ func (v *vaultClient) SetConfig(config *config.VaultConfig) error { v.l.Lock() defer v.l.Unlock() - // Store the new config - v.config = config - - if v.config.IsEnabled() { + // Kill any background routines + if v.running { // Stop accepting any new request v.connEstablished = false @@ -309,16 +316,23 @@ func (v *vaultClient) SetConfig(config *config.VaultConfig) error { v.tomb.Kill(nil) v.tomb.Wait() v.tomb = &tomb.Tomb{} + v.running = false + } + // Store the new config + v.config = config + + // Check if we should relaunch + if v.config.IsEnabled() { // Rebuild the client if err := v.buildClient(); err != nil { - v.l.Unlock() return err } // Launch the required goroutines v.tomb.Go(wrapNilError(v.establishConnection)) v.tomb.Go(wrapNilError(v.revokeDaemon)) + v.running = true } return nil diff --git a/nomad/vault_test.go b/nomad/vault_test.go index cfa639e719f..38c9ceea7e6 100644 --- a/nomad/vault_test.go +++ b/nomad/vault_test.go @@ -364,6 +364,40 @@ func TestVaultClient_SetConfig(t *testing.T) { } } +// Test that we can disable vault +func TestVaultClient_SetConfig_Disable(t *testing.T) { + v := testutil.NewTestVault(t).Start() + defer v.Stop() + + logger := log.New(os.Stderr, "", log.LstdFlags) + client, err := NewVaultClient(v.Config, logger, nil) + if err != nil { + t.Fatalf("failed to build vault client: %v", err) + } + defer client.Stop() + + waitForConnection(client, t) + + if client.tokenData == nil || len(client.tokenData.Policies) != 1 { + t.Fatalf("unexpected token: %v", client.tokenData) + } + + // Disable vault + f := false + config := config.VaultConfig{ + Enabled: &f, + } + + // Update the config + if err := client.SetConfig(&config); err != nil { + t.Fatalf("SetConfig failed: %v", err) + } + + if client.Enabled() || client.Running() { + t.Fatalf("SetConfig should have stopped client") + } +} + func 
TestVaultClient_RenewalLoop(t *testing.T) { v := testutil.NewTestVault(t).Start() defer v.Stop() diff --git a/nomad/vault_testing.go b/nomad/vault_testing.go index ebc163f9db2..efeea31bb5d 100644 --- a/nomad/vault_testing.go +++ b/nomad/vault_testing.go @@ -137,3 +137,4 @@ func (v *TestVaultClient) RevokeTokens(ctx context.Context, accessors []*structs func (v *TestVaultClient) Stop() {} func (v *TestVaultClient) SetActive(enabled bool) {} func (v *TestVaultClient) SetConfig(config *config.VaultConfig) error { return nil } +func (v *TestVaultClient) Running() bool { return true }