Skip to content

Commit

Permalink
client: manage TR kill from parent on SI token derivation failure
Browse files Browse the repository at this point in the history
Re-orient the management of the tr.kill to happen in the parent of
the spawned goroutine that is doing the actual token derivation. This
makes the code a little more straightforward, making it easier to
reason about not leaking the worker goroutine.
  • Loading branch information
shoenig committed Jan 15, 2020
1 parent 6146dd6 commit 63ce04e
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 14 deletions.
42 changes: 29 additions & 13 deletions client/allocrunner/taskrunner/sids_hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,23 +170,34 @@ func (h *sidsHook) recoverToken(dir string) (string, error) {
return string(token), nil
}

// siDerivationResult carries the outcome of an SI token derivation attempt
// from the goroutine performing the derivation back to the caller that is
// waiting on it.
type siDerivationResult struct {
	// token is the derived SI token; it is left empty when err is set.
	token string
	// err is the failure that ended the derivation, or nil on success.
	err error
}

// deriveSIToken spawns and waits on a goroutine which will make attempts to
// derive an SI token until a token is successfully created, or ctx is signaled
// done.
func (h *sidsHook) deriveSIToken(ctx context.Context) (string, error) {
ctx2, cancel := context.WithTimeout(ctx, h.derivationTimeout)
defer cancel()

tokenCh := make(chan string)
resultCh := make(chan siDerivationResult)

// keep trying to get the token in the background
go h.tryDerive(ctx2, tokenCh)
go h.tryDerive(ctx2, resultCh)

// wait until we get a token, or we get a signal to quit
for {
select {
case token := <-tokenCh:
return token, nil
case result := <-resultCh:
if result.err != nil {
h.kill(ctx, errors.Wrap(result.err, "consul: failed to derive SI token"))
return "", result.err
}
return result.token, nil
case <-ctx2.Done():
return "", ctx2.Err()
}
Expand All @@ -203,31 +214,36 @@ func (h *sidsHook) kill(ctx context.Context, err error) {
}

// tryDerive loops forever until a token is created, or ctx is done.
//
// Each attempt calls DeriveSITokens for this task and reports the outcome on
// ch: the token on success, or the error when the response has no token for
// the task, the error is server-side, or the error is not recoverable.
// Recoverable errors are logged and retried for as long as
// backoff(ctx, attempt) returns true.
//
// NOTE(review): this listing is a rendered diff, so what appear to be the
// removed lines (the chan<- string signature, the bare token send, and the
// h.kill calls) are shown alongside their replacements.
//
// NOTE(review): every send on ch is unconditional. deriveSIToken creates the
// channel unbuffered and returns once ctx2 is done, so a result produced after
// that point looks like it would block this goroutine forever. Consider a
// channel with a buffer of one, or a select on ctx.Done() around each send;
// confirm against the full file.
func (h *sidsHook) tryDerive(ctx context.Context, ch chan<- string) {
func (h *sidsHook) tryDerive(ctx context.Context, ch chan<- siDerivationResult) {
for attempt := 0; backoff(ctx, attempt); attempt++ {

tokens, err := h.sidsClient.DeriveSITokens(h.alloc, []string{h.task.Name})

switch {

case err == nil:
// nothing broke and we can return the token for the task
ch <- tokens[h.task.Name]
token, exists := tokens[h.task.Name]
// a nil error with no entry for this task is reported as a failure
// instead of handing an empty token back to the caller
if !exists {
err := errors.New("response does not include token for task")
h.logger.Error("derive SI token is missing token for task", "error", err, "task", h.task.Name)
ch <- siDerivationResult{token: "", err: err}
return
}
ch <- siDerivationResult{token: token, err: nil}
return

case structs.IsServerSide(err):
// the error is known to be a server problem, just die
h.logger.Error("failed to derive SI token", "error", err, "task", h.task.Name, "server_side", true)
h.kill(ctx, errors.Wrap(err, "consul: failed to derive SI token"))

// hand the error to the receiver, which calls h.kill on a non-nil result
ch <- siDerivationResult{token: "", err: err}
return
case !structs.IsRecoverable(err):
// the error is known not to be recoverable, just die
h.logger.Error("failed to derive SI token", "error", err, "task", h.task.Name, "recoverable", false)
h.kill(ctx, errors.Wrap(err, "consul: failed to derive SI token"))
ch <- siDerivationResult{token: "", err: err}
return

default:
// the error is marked recoverable, retry after some backoff
h.logger.Error("failed to derive SI token", "error", err, "recoverable", true)
h.logger.Error("failed attempt to derive SI token", "error", err, "recoverable", true)
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion client/allocrunner/taskrunner/task_runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1191,7 +1191,7 @@ func TestTaskRunner_DeriveSIToken_Retry(t *testing.T) {
trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
defer cleanup()

// control when we get a Consul SI token
// control when we get a Consul SI token (recoverable failure on first call)
token := uuid.Generate()
deriveCount := 0
deriveFn := func(*structs.Allocation, []string) (map[string]string, error) {
Expand Down

0 comments on commit 63ce04e

Please sign in to comment.