Skip to content

Commit

Permalink
Move location of quit channel closing in exp manager (#3638)
Browse files Browse the repository at this point in the history
* Move location of quit channel closing in exp manager

If it happens after stopping timers any timers firing before all timers
are stopped will still run the revocation function. With plugin
auto-crash-recovery this could end up instantiating a plugin that could
then try to unwrap a token from a nil token store.

This also plumbs in core so that we can grab a read lock during the
operation and check standby/sealed status before running it (after
grabbing the lock).

* Use context instead of checking core values directly

* Use official Go context in a few key places
  • Loading branch information
jefferai authored Dec 1, 2017
1 parent eed4579 commit 276a230
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 16 deletions.
9 changes: 4 additions & 5 deletions vault/core.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package vault

import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/subtle"
Expand All @@ -18,7 +19,6 @@ import (
"github.com/armon/go-metrics"
log "github.com/mgutz/logxi/v1"

"golang.org/x/net/context"
"google.golang.org/grpc"

"github.com/hashicorp/errwrap"
Expand Down Expand Up @@ -1498,8 +1498,6 @@ func (c *Core) sealInternal() error {
// Signal the standby goroutine to shutdown, wait for completion
close(c.standbyStopCh)

c.requestContext = nil

// Release the lock while we wait to avoid deadlocking
c.stateLock.Unlock()
<-c.standbyDoneCh
Expand Down Expand Up @@ -1536,9 +1534,8 @@ func (c *Core) postUnseal() (retErr error) {
defer metrics.MeasureSince([]string{"core", "post_unseal"}, time.Now())
defer func() {
if retErr != nil {
c.requestContextCancelFunc()
c.preSeal()
} else {
c.requestContext, c.requestContextCancelFunc = context.WithCancel(context.Background())
}
}()
c.logger.Info("core: post-unseal setup starting")
Expand All @@ -1559,6 +1556,8 @@ func (c *Core) postUnseal() (retErr error) {
c.seal.SetRecoveryConfig(nil)
}

c.requestContext, c.requestContextCancelFunc = context.WithCancel(context.Background())

if err := enterprisePostUnseal(c); err != nil {
return err
}
Expand Down
42 changes: 32 additions & 10 deletions vault/expiration.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package vault

import (
"context"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -68,29 +69,36 @@ type ExpirationManager struct {
restoreLocks []*locksutil.LockEntry
restoreLoaded sync.Map
quitCh chan struct{}

coreStateLock *sync.RWMutex
quitContext context.Context
}

// NewExpirationManager creates a new ExpirationManager that is backed
// using a given view, and uses the provided router for revocation.
func NewExpirationManager(router *Router, view *BarrierView, ts *TokenStore, logger log.Logger) *ExpirationManager {
if logger == nil {
logger = log.New("expiration_manager")
}

func NewExpirationManager(c *Core, view *BarrierView) *ExpirationManager {
exp := &ExpirationManager{
router: router,
router: c.router,
idView: view.SubView(leaseViewPrefix),
tokenView: view.SubView(tokenViewPrefix),
tokenStore: ts,
logger: logger,
tokenStore: c.tokenStore,
logger: c.logger,
pending: make(map[string]*time.Timer),

// new instances of the expiration manager will go immediately into
// restore mode
restoreMode: 1,
restoreLocks: locksutil.CreateLocks(),
quitCh: make(chan struct{}),

coreStateLock: &c.stateLock,
quitContext: c.requestContext,
}

if exp.logger == nil {
exp.logger = log.New("expiration_manager")
}

return exp
}

Expand All @@ -103,7 +111,7 @@ func (c *Core) setupExpiration() error {
view := c.systemBarrierView.SubView(expirationSubPath)

// Create the manager
mgr := NewExpirationManager(c.router, view, c.tokenStore, c.logger)
mgr := NewExpirationManager(c, view)
c.expiration = mgr

// Link the token store to this
Expand Down Expand Up @@ -430,14 +438,17 @@ func (m *ExpirationManager) Stop() error {
m.logger.Debug("expiration: stop triggered")
defer m.logger.Debug("expiration: finished stopping")

// Do this before stopping pending timers to avoid potential races with
// expiring timers
close(m.quitCh)

m.pendingLock.Lock()
for _, timer := range m.pending {
timer.Stop()
}
m.pending = make(map[string]*time.Timer)
m.pendingLock.Unlock()

close(m.quitCh)
if m.inRestoreMode() {
for {
if !m.inRestoreMode() {
Expand Down Expand Up @@ -969,13 +980,24 @@ func (m *ExpirationManager) expireID(leaseID string) {
return
default:
}

m.coreStateLock.RLock()
if m.quitContext.Err() == context.Canceled {
m.logger.Error("expiration: core context canceled, not attempting further revocation of lease", "lease_id", leaseID)
m.coreStateLock.RUnlock()
return
}

err := m.Revoke(leaseID)
if err == nil {
if m.logger.IsInfo() {
m.logger.Info("expiration: revoked lease", "lease_id", leaseID)
}
m.coreStateLock.RUnlock()
return
}

m.coreStateLock.RUnlock()
m.logger.Error("expiration: failed to revoke lease", "lease_id", leaseID, "error", err)
time.Sleep((1 << attempt) * revokeRetryBase)
}
Expand Down
2 changes: 1 addition & 1 deletion vault/request_forwarding.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package vault

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
Expand All @@ -13,7 +14,6 @@ import (
"time"

"github.com/hashicorp/vault/helper/forwarding"
"golang.org/x/net/context"
"golang.org/x/net/http2"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
Expand Down

0 comments on commit 276a230

Please sign in to comment.