diff --git a/client/vaultclient/vaultclient.go b/client/vaultclient/vaultclient.go index e08bd33eaaf..8d9dff28da2 100644 --- a/client/vaultclient/vaultclient.go +++ b/client/vaultclient/vaultclient.go @@ -4,10 +4,12 @@ import ( "container/heap" "fmt" "math/rand" + "net/http" "strings" "sync" "time" + "github.com/armon/go-metrics" hclog "github.com/hashicorp/go-hclog" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/nomad/structs/config" @@ -157,6 +159,10 @@ func NewVaultClient(config *config.VaultConfig, logger hclog.Logger, tokenDerive return nil, err } + client.SetHeaders(http.Header{ + "User-Agent": []string{"hashicorp/nomad"}, + }) + c.client = client return c, nil @@ -298,6 +304,7 @@ func (c *vaultClient) RenewToken(token string, increment int) (<-chan error, err // error channel. if err := c.renew(renewalReq); err != nil { c.logger.Error("error during renewal of token", "error", err) + metrics.IncrCounter([]string{"client", "vault", "renew_token_failure"}, 1) return nil, err } @@ -335,6 +342,7 @@ func (c *vaultClient) RenewLease(leaseId string, increment int) (<-chan error, e // Renew the secret and send any error to the dedicated error channel if err := c.renew(renewalReq); err != nil { c.logger.Error("error during renewal of lease", "error", err) + metrics.IncrCounter([]string{"client", "vault", "renew_lease_error"}, 1) return nil, err } @@ -531,6 +539,7 @@ func (c *vaultClient) run() { case <-renewalCh: if err := c.renew(renewalReq); err != nil { c.logger.Error("error renewing token", "error", err) + metrics.IncrCounter([]string{"client", "vault", "renew_token_error"}, 1) } case <-c.updateCh: continue diff --git a/nomad/server.go b/nomad/server.go index a3ea18860cb..89da7a9e4b9 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -1433,6 +1433,7 @@ func (s *Server) Stats() map[string]map[string]string { "raft": s.raft.Stats(), "serf": s.serf.Stats(), "runtime": stats.RuntimeStats(), + "vault": s.vault.Stats(), } return stats diff --git a/nomad/vault.go b/nomad/vault.go index 91fc30df0c5..cedfd056402 100644 --- a/nomad/vault.go +++ b/nomad/vault.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "math/rand" + "strconv" "sync" "sync/atomic" "time" @@ -127,7 +128,7 @@ type VaultClient interface { Running() bool // Stats returns the Vault clients statistics - Stats() *VaultStats + Stats() map[string]string // EmitStats emits that clients statistics at the given period until stopCh // is called. @@ -140,6 +141,12 @@ type VaultStats struct { // TrackedForRevoke is the count of tokens that are being tracked to be // revoked since they could not be immediately revoked. TrackedForRevoke int + + // TokenTTL is the time-to-live duration for the current token + TokenTTL time.Duration + + // TokenExpiry Time is the recoreded expiry time of the current token + TokenExpiry time.Time } // PurgeVaultAccessor is called to remove VaultAccessors from the system. If @@ -200,19 +207,20 @@ type vaultClient struct { // running indicates whether the vault client is started. running bool + // renewLoopActive indicates whether the renewal goroutine is running + // It should be accessed and updated atomically + // used for testing purposes only + renewLoopActive int32 + // childTTL is the TTL for child tokens. childTTL string - // lastRenewed is the time the token was last renewed - lastRenewed time.Time + // currentExpiration is the time the current token lease expires + currentExpiration time.Time tomb *tomb.Tomb logger log.Logger - // stats stores the stats - stats *VaultStats - statsLock sync.RWMutex - // l is used to lock the configuration aspects of the client such that // multiple callers can't cause conflicting config updates l sync.Mutex @@ -236,7 +244,6 @@ func NewVaultClient(c *config.VaultConfig, logger log.Logger, purgeFn PurgeVault revoking: make(map[*structs.VaultAccessor]time.Time), purgeFn: purgeFn, tomb: &tomb.Tomb{}, - stats: new(VaultStats), } if v.config.IsEnabled() { @@ -456,9 +463,16 @@ OUTER: v.l.Unlock() } +func (v *vaultClient) isRenewLoopActive() bool { + return atomic.LoadInt32(&v.renewLoopActive) == 1 +} + // renewalLoop runs the renew loop. This should only be called if we are given a // non-root token. func (v *vaultClient) renewalLoop() { + atomic.StoreInt32(&v.renewLoopActive, 1) + defer atomic.StoreInt32(&v.renewLoopActive, 0) + // Create the renewal timer and set initial duration to zero so it fires // immediately authRenewTimer := time.NewTimer(0) @@ -473,14 +487,12 @@ func (v *vaultClient) renewalLoop() { return case <-authRenewTimer.C: // Renew the token and determine the new expiration - err := v.renew() - currentExpiration := v.lastRenewed.Add(time.Duration(v.tokenData.CreationTTL) * time.Second) + recoverable, err := v.renew() + currentExpiration := v.currentExpiration // Successfully renewed if err == nil { - // If we take the expiration (lastRenewed + auth duration) and - // subtract the current time, we get a duration until expiry. - // Set the timer to poke us after half of that time is up. + // Attempt to renew the token at half the expiration time durationUntilRenew := currentExpiration.Sub(time.Now()) / 2 v.logger.Info("successfully renewed token", "next_renewal", durationUntilRenew) @@ -491,33 +503,15 @@ func (v *vaultClient) renewalLoop() { break } - // Back off, increasing the amount of backoff each time. There are some rules: - // - // * If we have an existing authentication that is going to expire, - // never back off more than half of the amount of time remaining - // until expiration - // * Never back off more than 30 seconds multiplied by a random - // value between 1 and 2 - // * Use randomness so that many clients won't keep hitting Vault - // at the same time - - // Set base values and add some backoff - - v.logger.Warn("got error or bad auth, so backing off", "error", err) - switch { - case backoff < 5: - backoff = 5 - case backoff >= 24: - backoff = 30 - default: - backoff = backoff * 1.25 - } + metrics.IncrCounter([]string{"nomad", "vault", "renew_failed"}, 1) + v.logger.Warn("got error or bad auth, so backing off", "error", err, "recoverable", recoverable) - // Add randomness - backoff = backoff * (1.0 + rand.Float64()) + if !recoverable { + return + } - maxBackoff := currentExpiration.Sub(time.Now()) / 2 - if maxBackoff < 0 { + backoff = nextBackoff(backoff, currentExpiration) + if backoff < 0 { // We have failed to renew the token past its expiration. Stop // renewing with Vault. v.logger.Error("failed to renew Vault token before lease expiration. Shutting down Vault client") @@ -526,9 +520,6 @@ func (v *vaultClient) renewalLoop() { v.connEstablishedErr = err v.l.Unlock() return - - } else if backoff > maxBackoff.Seconds() { - backoff = maxBackoff.Seconds() } durationUntilRetry := time.Duration(backoff) * time.Second @@ -539,30 +530,82 @@ func (v *vaultClient) renewalLoop() { } } +// nextBackoff returns the delay for the next auto renew interval, in seconds. +// Returns negative value if past expiration +// +// It should increase the amount of backoff each time, with the following rules: +// +// * If token expired already despite earlier renewal attempts, +// back off for 1 minute + jitter +// * If we have an existing authentication that is going to expire, +// never back off more than half of the amount of time remaining +// until expiration (with 5s floor) +// * Never back off more than 30 seconds multiplied by a random +// value between 1 and 2 +// * Use randomness so that many clients won't keep hitting Vault +// at the same time +func nextBackoff(backoff float64, expiry time.Time) float64 { + maxBackoff := time.Until(expiry) / 2 + + if maxBackoff < 0 { + // expiry passed + return 60 * (1.0 + rand.Float64()) + } + + switch { + case backoff >= 24: + backoff = 30 + default: + backoff = backoff * 1.25 + } + + // Add randomness + backoff = backoff * (1.0 + rand.Float64()) + + if backoff > maxBackoff.Seconds() { + backoff = maxBackoff.Seconds() + } + + if backoff < 5 { + backoff = 5 + } + + return backoff +} + // renew attempts to renew our Vault token. If the renewal fails, an error is -// returned. This method updates the lastRenewed time -func (v *vaultClient) renew() error { +// returned. The boolean indicates whether it's safe to attempt to renew again. +// This method updates the currentExpiration time +func (v *vaultClient) renew() (bool, error) { + // Track how long the request takes + defer metrics.MeasureSince([]string{"nomad", "vault", "renew"}, time.Now()) + // Attempt to renew the token secret, err := v.auth.RenewSelf(v.tokenData.CreationTTL) if err != nil { - return err + // Check if there is a permission denied + recoverable := !structs.VaultUnrecoverableError.MatchString(err.Error()) + return recoverable, fmt.Errorf("failed to renew the vault token: %v", err) } + if secret == nil { // It's possible for RenewSelf to return (nil, nil) if the // response body from Vault is empty. - return fmt.Errorf("renewal failed: empty response from vault") + return true, fmt.Errorf("renewal failed: empty response from vault") } + // these treated as transient errors, where can keep renewing auth := secret.Auth if auth == nil { - return fmt.Errorf("renewal successful but not auth information returned") + return true, fmt.Errorf("renewal successful but not auth information returned") } else if auth.LeaseDuration == 0 { - return fmt.Errorf("renewal successful but no lease duration returned") + return true, fmt.Errorf("renewal successful but no lease duration returned") } - v.lastRenewed = time.Now() + v.currentExpiration = time.Now().Add(time.Duration(auth.LeaseDuration) * time.Second) + v.logger.Debug("successfully renewed server token") - return nil + return true, nil } // getWrappingFn returns an appropriate wrapping function for Nomad Servers @@ -607,7 +650,6 @@ func (v *vaultClient) parseSelfToken() error { if err := mapstructure.WeakDecode(self.Data, &data); err != nil { return fmt.Errorf("failed to parse Vault token's data block: %v", err) } - root := false for _, p := range data.Policies { if p == "root" { @@ -615,10 +657,9 @@ func (v *vaultClient) parseSelfToken() error { break } } - - // Store the token data data.Root = root v.tokenData = &data + v.currentExpiration = time.Now().Add(time.Duration(data.TTL) * time.Second) // The criteria that must be met for the token to be valid are as follows: // 1) If token is non-root or is but has a creation ttl @@ -637,7 +678,7 @@ func (v *vaultClient) parseSelfToken() error { var mErr multierror.Error role := v.getRole() - if !root { + if !data.Root { // All non-root tokens must be renewable if !data.Renewable { multierror.Append(&mErr, fmt.Errorf("Vault token is not renewable or root")) @@ -669,7 +710,7 @@ func (v *vaultClient) parseSelfToken() error { } // Check we have the correct capabilities - if err := v.validateCapabilities(role, root); err != nil { + if err := v.validateCapabilities(role, data.Root); err != nil { multierror.Append(&mErr, err) } @@ -899,6 +940,7 @@ func (v *vaultClient) CreateToken(ctx context.Context, a *structs.Allocation, ta // Determine whether it is unrecoverable if err != nil { + err = fmt.Errorf("failed to create an alloc vault token: %v", err) if structs.VaultUnrecoverableError.MatchString(err.Error()) { return secret, err } @@ -1035,13 +1077,11 @@ func (v *vaultClient) RevokeTokens(ctx context.Context, accessors []*structs.Vau // time. func (v *vaultClient) storeForRevocation(accessors []*structs.VaultAccessor) { v.revLock.Lock() - v.statsLock.Lock() + now := time.Now() for _, a := range accessors { v.revoking[a] = now.Add(time.Duration(a.CreationTTL) * time.Second) } - v.stats.TrackedForRevoke = len(v.revoking) - v.statsLock.Unlock() v.revLock.Unlock() } @@ -1162,12 +1202,9 @@ func (v *vaultClient) revokeDaemon() { // Can delete from the tracked list now that we have purged v.revLock.Lock() - v.statsLock.Lock() for _, va := range revoking { delete(v.revoking, va) } - v.stats.TrackedForRevoke = len(v.revoking) - v.statsLock.Unlock() v.revLock.Unlock() } @@ -1199,16 +1236,34 @@ func (v *vaultClient) setLimit(l rate.Limit) { v.limiter = rate.NewLimiter(l, int(l)) } -// Stats is used to query the state of the blocked eval tracker. -func (v *vaultClient) Stats() *VaultStats { +func (v *vaultClient) Stats() map[string]string { + stat := v.stats() + + expireTimeStr := "" + + if !stat.TokenExpiry.IsZero() { + expireTimeStr = stat.TokenExpiry.Format(time.RFC3339) + } + + return map[string]string{ + "tracked_for_revoked": strconv.Itoa(stat.TrackedForRevoke), + "token_ttl": stat.TokenTTL.Round(time.Second).String(), + "token_expire_time": expireTimeStr, + } +} + +func (v *vaultClient) stats() *VaultStats { // Allocate a new stats struct stats := new(VaultStats) - v.statsLock.RLock() - defer v.statsLock.RUnlock() + v.revLock.Lock() + stats.TrackedForRevoke = len(v.revoking) + v.revLock.Unlock() - // Copy all the stats - stats.TrackedForRevoke = v.stats.TrackedForRevoke + stats.TokenExpiry = v.currentExpiration + if !stats.TokenExpiry.IsZero() { + stats.TokenTTL = time.Until(stats.TokenExpiry) + } return stats } @@ -1218,8 +1273,10 @@ func (v *vaultClient) EmitStats(period time.Duration, stopCh chan struct{}) { for { select { case <-time.After(period): - stats := v.Stats() + stats := v.stats() metrics.SetGauge([]string{"nomad", "vault", "distributed_tokens_revoking"}, float32(stats.TrackedForRevoke)) + metrics.SetGauge([]string{"nomad", "vault", "token_ttl"}, float32(stats.TokenTTL/time.Millisecond)) + case <-stopCh: return } diff --git a/nomad/vault_test.go b/nomad/vault_test.go index f7fdbddb82f..089efca3e8d 100644 --- a/nomad/vault_test.go +++ b/nomad/vault_test.go @@ -3,6 +3,7 @@ package nomad import ( "context" "encoding/json" + "errors" "fmt" "math/rand" "reflect" @@ -10,6 +11,9 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/time/rate" "github.com/hashicorp/nomad/helper" @@ -528,6 +532,119 @@ func TestVaultClient_RenewalLoop(t *testing.T) { if ttl == 0 { t.Fatalf("token renewal failed; ttl %v", ttl) } + + if client.currentExpiration.Before(time.Now()) { + t.Fatalf("found current expiration to be in past %s", time.Until(client.currentExpiration)) + } +} + +func TestVaultClientRenewUpdatesExpiration(t *testing.T) { + t.Parallel() + v := testutil.NewTestVault(t) + defer v.Stop() + + // Set the configs token in a new test role + v.Config.Token = defaultTestVaultWhitelistRoleAndToken(v, t, 5) + + // Start the client + logger := testlog.HCLogger(t) + client, err := NewVaultClient(v.Config, logger, nil) + if err != nil { + t.Fatalf("failed to build vault client: %v", err) + } + defer client.Stop() + + // Get the current TTL + a := v.Client.Auth().Token() + s2, err := a.Lookup(v.Config.Token) + if err != nil { + t.Fatalf("failed to lookup token: %v", err) + } + exp0 := time.Now().Add(time.Duration(parseTTLFromLookup(s2, t)) * time.Second) + + time.Sleep(1 * time.Second) + + _, err = client.renew() + require.NoError(t, err) + exp1 := client.currentExpiration + require.True(t, exp0.Before(exp1)) + + time.Sleep(1 * time.Second) + + _, err = client.renew() + require.NoError(t, err) + exp2 := client.currentExpiration + require.True(t, exp1.Before(exp2)) +} + +func TestVaultClient_StopsAfterPermissionError(t *testing.T) { + t.Parallel() + v := testutil.NewTestVault(t) + defer v.Stop() + + // Set the configs token in a new test role + v.Config.Token = defaultTestVaultWhitelistRoleAndToken(v, t, 2) + + // Start the client + logger := testlog.HCLogger(t) + client, err := NewVaultClient(v.Config, logger, nil) + if err != nil { + t.Fatalf("failed to build vault client: %v", err) + } + defer client.Stop() + + time.Sleep(500 * time.Millisecond) + + assert.True(t, client.isRenewLoopActive()) + + // Get the current TTL + a := v.Client.Auth().Token() + assert.NoError(t, a.RevokeSelf("")) + + testutil.WaitForResult(func() (bool, error) { + if !client.isRenewLoopActive() { + return true, nil + } else { + return false, errors.New("renew loop should terminate after token is revoked") + } + }, func(err error) { + t.Fatalf("err: %v", err) + }) +} +func TestVaultClient_LoopsUntilCannotRenew(t *testing.T) { + t.Parallel() + v := testutil.NewTestVault(t) + defer v.Stop() + + // Set the configs token in a new test role + v.Config.Token = defaultTestVaultWhitelistRoleAndToken(v, t, 5) + + // Start the client + logger := testlog.HCLogger(t) + client, err := NewVaultClient(v.Config, logger, nil) + if err != nil { + t.Fatalf("failed to build vault client: %v", err) + } + defer client.Stop() + + // Sleep 8 seconds and ensure we have a non-zero TTL + time.Sleep(8 * time.Second) + + // Get the current TTL + a := v.Client.Auth().Token() + s2, err := a.Lookup(v.Config.Token) + if err != nil { + t.Fatalf("failed to lookup token: %v", err) + } + + ttl := parseTTLFromLookup(s2, t) + if ttl == 0 { + t.Fatalf("token renewal failed; ttl %v", ttl) + } + + if client.currentExpiration.Before(time.Now()) { + t.Fatalf("found current expiration to be in past %s", time.Until(client.currentExpiration)) + } } func parseTTLFromLookup(s *vapi.Secret, t *testing.T) int64 { @@ -1114,7 +1231,7 @@ func TestVaultClient_RevokeTokens_PreEstablishs(t *testing.T) { t.Fatalf("didn't add to revoke loop") } - if client.Stats().TrackedForRevoke != 2 { + if client.stats().TrackedForRevoke != 2 { t.Fatalf("didn't add to revoke loop") } } @@ -1258,3 +1375,42 @@ func waitForConnection(v *vaultClient, t *testing.T) { t.Fatalf("Connection not established") }) } + +func TestVaultClient_nextBackoff(t *testing.T) { + simpleCases := []struct { + name string + initBackoff float64 + + // define range of acceptable backoff values accounting for random factor + rangeMin float64 + rangeMax float64 + }{ + {"simple case", 7.0, 8.7, 17.60}, + {"too low", 2.0, 5.0, 10.0}, + {"too large", 100, 30.0, 60.0}, + } + + for _, c := range simpleCases { + t.Run(c.name, func(t *testing.T) { + b := nextBackoff(c.initBackoff, time.Now().Add(10*time.Hour)) + if !(c.rangeMin <= b && b <= c.rangeMax) { + t.Fatalf("Expected backoff within [%v, %v] but found %v", c.rangeMin, c.rangeMax, b) + } + }) + } + + // some edge cases + t.Run("close to expiry", func(t *testing.T) { + b := nextBackoff(20, time.Now().Add(1100*time.Millisecond)) + if b != 5.0 { + t.Fatalf("Expected backoff is 5 but found %v", b) + } + }) + + t.Run("past expiry", func(t *testing.T) { + b := nextBackoff(20, time.Now().Add(-1100*time.Millisecond)) + if !(60 <= b && b <= 120) { + t.Fatalf("Expected backoff within [%v, %v] but found %v", 60, 120, b) + } + }) +} diff --git a/nomad/vault_testing.go b/nomad/vault_testing.go index d5a361c9e5c..9ac8f30f42e 100644 --- a/nomad/vault_testing.go +++ b/nomad/vault_testing.go @@ -139,5 +139,5 @@ func (v *TestVaultClient) Stop() func (v *TestVaultClient) SetActive(enabled bool) {} func (v *TestVaultClient) SetConfig(config *config.VaultConfig) error { return nil } func (v *TestVaultClient) Running() bool { return true } -func (v *TestVaultClient) Stats() *VaultStats { return new(VaultStats) } +func (v *TestVaultClient) Stats() map[string]string { return map[string]string{} } func (v *TestVaultClient) EmitStats(period time.Duration, stopCh chan struct{}) {}