diff --git a/client/vaultclient/vaultclient.go b/client/vaultclient/vaultclient.go index fce8dfc4390..ea221bba6c1 100644 --- a/client/vaultclient/vaultclient.go +++ b/client/vaultclient/vaultclient.go @@ -188,28 +188,36 @@ func (c *vaultClient) isTracked(id string) bool { return ok } +// isRunning returns true if the client is running. +func (c *vaultClient) isRunning() bool { + c.lock.RLock() + defer c.lock.RUnlock() + return c.running +} + // Starts the renewal loop of vault client func (c *vaultClient) Start() { + c.lock.Lock() + defer c.lock.Unlock() + if !c.config.IsEnabled() || c.running { return } - c.lock.Lock() c.running = true - c.lock.Unlock() go c.run() } // Stops the renewal loop of vault client func (c *vaultClient) Stop() { + c.lock.Lock() + defer c.lock.Unlock() + if !c.config.IsEnabled() || !c.running { return } - c.lock.Lock() - defer c.lock.Unlock() - c.running = false close(c.stopCh) } @@ -229,7 +237,7 @@ func (c *vaultClient) DeriveToken(alloc *structs.Allocation, taskNames []string) if !c.config.IsEnabled() { return nil, fmt.Errorf("vault client not enabled") } - if !c.running { + if !c.isRunning() { return nil, fmt.Errorf("vault client is not running") } @@ -499,7 +507,7 @@ func (c *vaultClient) run() { } var renewalCh <-chan time.Time - for c.config.IsEnabled() && c.running { + for c.config.IsEnabled() && c.isRunning() { // Fetches the candidate for next renewal renewalReq, renewalTime := c.nextRenewal() if renewalTime.IsZero() { diff --git a/client/vaultclient/vaultclient_test.go b/client/vaultclient/vaultclient_test.go index d8094eb48ab..6f5ef6645cc 100644 --- a/client/vaultclient/vaultclient_test.go +++ b/client/vaultclient/vaultclient_test.go @@ -76,8 +76,11 @@ func TestVaultClient_TokenRenewals(t *testing.T) { }(errCh) } - if c.heap.Length() != num { - t.Fatalf("bad: heap length: expected: %d, actual: %d", num, c.heap.Length()) + c.lock.Lock() + length := c.heap.Length() + c.lock.Unlock() + if length != num { + t.Fatalf("bad: heap length: expected: %d, actual: %d", num, length) } time.Sleep(time.Duration(testutil.TestMultiplier()) * time.Second) @@ -88,8 +91,11 @@ func TestVaultClient_TokenRenewals(t *testing.T) { } } - if c.heap.Length() != 0 { - t.Fatalf("bad: heap length: expected: 0, actual: %d", c.heap.Length()) + c.lock.Lock() + length = c.heap.Length() + c.lock.Unlock() + if length != 0 { + t.Fatalf("bad: heap length: expected: 0, actual: %d", length) } }