Skip to content

Commit

Permalink
vault: fix data races
Browse files Browse the repository at this point in the history
  • Loading branch information
schmichael committed Mar 27, 2019
1 parent 863f836 commit 516def8
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
22 changes: 15 additions & 7 deletions client/vaultclient/vaultclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,28 +188,36 @@ func (c *vaultClient) isTracked(id string) bool {
return ok
}

// isRunning returns true if the client is running.
func (c *vaultClient) isRunning() bool {
c.lock.RLock()
defer c.lock.RUnlock()
return c.running
}

// Starts the renewal loop of vault client
func (c *vaultClient) Start() {
c.lock.Lock()
defer c.lock.Unlock()

if !c.config.IsEnabled() || c.running {
return
}

c.lock.Lock()
c.running = true
c.lock.Unlock()

go c.run()
}

// Stops the renewal loop of vault client
func (c *vaultClient) Stop() {
c.lock.Lock()
defer c.lock.Unlock()

if !c.config.IsEnabled() || !c.running {
return
}

c.lock.Lock()
defer c.lock.Unlock()

c.running = false
close(c.stopCh)
}
Expand All @@ -229,7 +237,7 @@ func (c *vaultClient) DeriveToken(alloc *structs.Allocation, taskNames []string)
if !c.config.IsEnabled() {
return nil, fmt.Errorf("vault client not enabled")
}
if !c.running {
if !c.isRunning() {
return nil, fmt.Errorf("vault client is not running")
}

Expand Down Expand Up @@ -499,7 +507,7 @@ func (c *vaultClient) run() {
}

var renewalCh <-chan time.Time
for c.config.IsEnabled() && c.running {
for c.config.IsEnabled() && c.isRunning() {
// Fetches the candidate for next renewal
renewalReq, renewalTime := c.nextRenewal()
if renewalTime.IsZero() {
Expand Down
14 changes: 10 additions & 4 deletions client/vaultclient/vaultclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,11 @@ func TestVaultClient_TokenRenewals(t *testing.T) {
}(errCh)
}

if c.heap.Length() != num {
t.Fatalf("bad: heap length: expected: %d, actual: %d", num, c.heap.Length())
c.lock.Lock()
length := c.heap.Length()
c.lock.Unlock()
if length != num {
t.Fatalf("bad: heap length: expected: %d, actual: %d", num, length)
}

time.Sleep(time.Duration(testutil.TestMultiplier()) * time.Second)
Expand All @@ -88,8 +91,11 @@ func TestVaultClient_TokenRenewals(t *testing.T) {
}
}

if c.heap.Length() != 0 {
t.Fatalf("bad: heap length: expected: 0, actual: %d", c.heap.Length())
c.lock.Lock()
length = c.heap.Length()
c.lock.Unlock()
if length != 0 {
t.Fatalf("bad: heap length: expected: 0, actual: %d", length)
}
}

Expand Down

0 comments on commit 516def8

Please sign in to comment.