Skip to content

Commit

Permalink
Thread through whether DeriveToken error is recoverable or not
Browse files Browse the repository at this point in the history
  • Loading branch information
dadgar committed Oct 23, 2016
1 parent 0e296f4 commit 42f7bc8
Show file tree
Hide file tree
Showing 13 changed files with 389 additions and 105 deletions.
4 changes: 4 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1714,6 +1714,10 @@ func (c *Client) deriveToken(alloc *structs.Allocation, taskNames []string, vcli
c.logger.Printf("[ERR] client.vault: failed to derive vault tokens: %v", err)
return nil, fmt.Errorf("failed to derive vault tokens: %v", err)
}
if resp.Error != nil {
c.logger.Printf("[ERR] client.vault: failed to derive vault tokens: %v", resp.Error)
return nil, resp.Error
}
if resp.Tasks == nil {
c.logger.Printf("[ERR] client.vault: failed to derive vault token: invalid response")
return nil, fmt.Errorf("failed to derive vault tokens: invalid response")
Expand Down
2 changes: 1 addition & 1 deletion client/driver/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ func (d *DockerDriver) recoverablePullError(err error, image string) error {
if imageNotFoundMatcher.MatchString(err.Error()) {
recoverable = false
}
return dstructs.NewRecoverableError(fmt.Errorf("Failed to pull `%s`: %s", image, err), recoverable)
return structs.NewRecoverableError(fmt.Errorf("Failed to pull `%s`: %s", image, err), recoverable)
}

func (d *DockerDriver) Periodic() (bool, time.Duration) {
Expand Down
20 changes: 0 additions & 20 deletions client/driver/structs/structs.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,26 +37,6 @@ func (r *WaitResult) String() string {
r.ExitCode, r.Signal, r.Err)
}

// RecoverableError wraps an error and marks whether it is recoverable and could
// be retried or it is fatal.
type RecoverableError struct {
Err error
Recoverable bool
}

// NewRecoverableError is used to wrap an error and mark it as recoverable or
// not.
func NewRecoverableError(e error, recoverable bool) *RecoverableError {
return &RecoverableError{
Err: e,
Recoverable: recoverable,
}
}

func (r *RecoverableError) Error() string {
return r.Err.Error()
}

// CheckResult encapsulates the result of a check
type CheckResult struct {

Expand Down
8 changes: 4 additions & 4 deletions client/restarts.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"sync"
"time"

cstructs "github.com/hashicorp/nomad/client/driver/structs"
dstructs "github.com/hashicorp/nomad/client/driver/structs"
"github.com/hashicorp/nomad/nomad/structs"
)

Expand Down Expand Up @@ -34,7 +34,7 @@ func newRestartTracker(policy *structs.RestartPolicy, jobType string) *RestartTr
}

type RestartTracker struct {
waitRes *cstructs.WaitResult
waitRes *dstructs.WaitResult
startErr error
restartTriggered bool // Whether the task has been signalled to be restarted
count int // Current number of attempts.
Expand Down Expand Up @@ -63,7 +63,7 @@ func (r *RestartTracker) SetStartError(err error) *RestartTracker {
}

// SetWaitResult is used to mark the most recent wait result.
func (r *RestartTracker) SetWaitResult(res *cstructs.WaitResult) *RestartTracker {
func (r *RestartTracker) SetWaitResult(res *dstructs.WaitResult) *RestartTracker {
r.lock.Lock()
defer r.lock.Unlock()
r.waitRes = res
Expand Down Expand Up @@ -149,7 +149,7 @@ func (r *RestartTracker) GetState() (string, time.Duration) {
// infinitely try to start a task.
func (r *RestartTracker) handleStartError() (string, time.Duration) {
// If the error is not recoverable, do not restart.
if rerr, ok := r.startErr.(*cstructs.RecoverableError); !(ok && rerr.Recoverable) {
if rerr, ok := r.startErr.(*structs.RecoverableError); !(ok && rerr.Recoverable) {
r.reason = ReasonUnrecoverableErrror
return structs.TaskNotRestarting, 0
}
Expand Down
4 changes: 2 additions & 2 deletions client/restarts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func TestClient_RestartTracker_StartError_Recoverable_Fail(t *testing.T) {
t.Parallel()
p := testPolicy(true, structs.RestartPolicyModeFail)
rt := newRestartTracker(p, structs.JobTypeSystem)
recErr := cstructs.NewRecoverableError(fmt.Errorf("foo"), true)
recErr := structs.NewRecoverableError(fmt.Errorf("foo"), true)
for i := 0; i < p.Attempts; i++ {
state, when := rt.SetStartError(recErr).GetState()
if state != structs.TaskRestarting {
Expand All @@ -129,7 +129,7 @@ func TestClient_RestartTracker_StartError_Recoverable_Delay(t *testing.T) {
t.Parallel()
p := testPolicy(true, structs.RestartPolicyModeDelay)
rt := newRestartTracker(p, structs.JobTypeSystem)
recErr := cstructs.NewRecoverableError(fmt.Errorf("foo"), true)
recErr := structs.NewRecoverableError(fmt.Errorf("foo"), true)
for i := 0; i < p.Attempts; i++ {
state, when := rt.SetStartError(recErr).GetState()
if state != structs.TaskRestarting {
Expand Down
27 changes: 18 additions & 9 deletions client/task_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,10 +509,10 @@ OUTER:
// restoring the TaskRunner
if token == "" {
// Get a token
var ok bool
token, ok = r.deriveVaultToken()
if !ok {
// We are shutting down
var exit bool
token, exit = r.deriveVaultToken()
if exit {
// Exit the manager
return
}

Expand Down Expand Up @@ -589,27 +589,36 @@ OUTER:
// deriveVaultToken derives the Vault token using exponential backoffs. It
// returns the Vault token and whether the token is valid. If it is not valid we
// are shutting down
func (r *TaskRunner) deriveVaultToken() (string, bool) {
func (r *TaskRunner) deriveVaultToken() (token string, exit bool) {
attempts := 0
for {
tokens, err := r.vaultClient.DeriveToken(r.alloc, []string{r.task.Name})
if err == nil {
return tokens[r.task.Name], true
return tokens[r.task.Name], false
}

// Check if we can't recover from the error
if rerr, ok := err.(*structs.RecoverableError); !ok || !rerr.Recoverable {
r.logger.Printf("[ERR] client: failed to derive Vault token for task %v on alloc %q: %v",
r.task.Name, r.alloc.ID, err)
r.Kill("vault", fmt.Sprintf("failed to derive token: %v", err))
return "", true
}

// Handle the retry case
backoff := (1 << (2 * uint64(attempts))) * vaultBackoffBaseline
if backoff > vaultBackoffLimit {
backoff = vaultBackoffLimit
}
r.logger.Printf("[ERR] client: failed to derive Vault token for task %v on alloc %q: %v; retrying in %v", r.task.Name, r.alloc.ID, err, backoff)
r.logger.Printf("[ERR] client: failed to derive Vault token for task %v on alloc %q: %v; retrying in %v",
r.task.Name, r.alloc.ID, err, backoff)

attempts++

// Wait till retrying
select {
case <-r.waitCh:
return "", false
return "", true
case <-time.After(backoff):
}
}
Expand Down Expand Up @@ -706,7 +715,7 @@ func (r *TaskRunner) prestart(resultCh chan bool) {
if err := getter.GetArtifact(r.getTaskEnv(), artifact, r.taskDir); err != nil {
r.setState(structs.TaskStatePending,
structs.NewTaskEvent(structs.TaskArtifactDownloadFailed).SetDownloadError(err))
r.restartTracker.SetStartError(dstructs.NewRecoverableError(err, true))
r.restartTracker.SetStartError(structs.NewRecoverableError(err, true))
goto RESTART
}
}
Expand Down
45 changes: 44 additions & 1 deletion client/task_runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,7 @@ func TestTaskRunner_DeriveToken_Retry(t *testing.T) {
}

count++
return nil, fmt.Errorf("Want a retry")
return nil, structs.NewRecoverableError(fmt.Errorf("Want a retry"), true)
}
tr.vaultClient.(*vaultclient.MockVaultClient).DeriveTokenFn = handler
go tr.Run()
Expand Down Expand Up @@ -770,6 +770,49 @@ func TestTaskRunner_DeriveToken_Retry(t *testing.T) {
}
}

func TestTaskRunner_DeriveToken_Unrecoverable(t *testing.T) {
alloc := mock.Alloc()
task := alloc.Job.TaskGroups[0].Tasks[0]
task.Driver = "mock_driver"
task.Config = map[string]interface{}{
"exit_code": "0",
"run_for": "10s",
}
task.Vault = &structs.Vault{
Policies: []string{"default"},
ChangeMode: structs.VaultChangeModeRestart,
}

upd, tr := testTaskRunnerFromAlloc(false, alloc)
tr.MarkReceived()
defer tr.Destroy(structs.NewTaskEvent(structs.TaskKilled))
defer tr.ctx.AllocDir.Destroy()

// Error the token derivation
vc := tr.vaultClient.(*vaultclient.MockVaultClient)
vc.SetDeriveTokenError(alloc.ID, []string{task.Name}, fmt.Errorf("Non recoverable"))
go tr.Run()

// Wait for the task to start
testutil.WaitForResult(func() (bool, error) {
if l := len(upd.events); l != 2 {
return false, fmt.Errorf("Expect two events; got %v", l)
}

if upd.events[0].Type != structs.TaskReceived {
return false, fmt.Errorf("First Event was %v; want %v", upd.events[0].Type, structs.TaskReceived)
}

if upd.events[1].Type != structs.TaskKilling {
return false, fmt.Errorf("Second Event was %v; want %v", upd.events[1].Type, structs.TaskKilling)
}

return true, nil
}, func(err error) {
t.Fatalf("err: %v", err)
})
}

func TestTaskRunner_Template_Block(t *testing.T) {
alloc := mock.Alloc()
task := alloc.Job.TaskGroups[0].Tasks[0]
Expand Down
Loading

0 comments on commit 42f7bc8

Please sign in to comment.