diff --git a/api/tasks.go b/api/tasks.go
index dfb57bddf03..c8e70d89542 100644
--- a/api/tasks.go
+++ b/api/tasks.go
@@ -233,6 +233,8 @@ const (
 	TaskDownloadingArtifacts   = "Downloading Artifacts"
 	TaskArtifactDownloadFailed = "Failed Artifact Download"
 	TaskDiskExceeded           = "Disk Exceeded"
+	TaskVaultRenewalFailed     = "Vault token renewal failed"
+	TaskSiblingFailed          = "Sibling task failed"
 )
 
 // TaskEvent is an event that effects the state of a task and contains meta-data
@@ -250,4 +252,8 @@ type TaskEvent struct {
 	StartDelay      int64
 	DownloadError   string
 	ValidationError string
+	DiskLimit       int64
+	DiskSize        int64
+	FailedSibling   string
+	VaultError      string
 }
diff --git a/client/alloc_runner.go b/client/alloc_runner.go
index 00d76746422..f84ffbb039b 100644
--- a/client/alloc_runner.go
+++ b/client/alloc_runner.go
@@ -2,6 +2,7 @@ package client
 
 import (
 	"fmt"
+	"io/ioutil"
 	"log"
 	"os"
 	"path/filepath"
@@ -12,6 +13,7 @@ import (
 	"github.com/hashicorp/nomad/client/allocdir"
 	"github.com/hashicorp/nomad/client/config"
 	"github.com/hashicorp/nomad/client/driver"
+	"github.com/hashicorp/nomad/client/vaultclient"
 	"github.com/hashicorp/nomad/nomad/structs"
 
 	cstructs "github.com/hashicorp/nomad/client/structs"
@@ -29,6 +31,10 @@ const (
 	// watchdogInterval is the interval at which resource constraints for the
 	// allocation are being checked and enforced.
 	watchdogInterval = 5 * time.Second
+
+	// vaultTokenFile is the name of the file holding the Vault token inside the
+	// task's secret directory
+	vaultTokenFile = "vault_token"
 )
 
 // AllocStateUpdater is used to update the status of an allocation
@@ -62,6 +68,9 @@ type AllocRunner struct {
 
 	updateCh chan *structs.Allocation
 
+	vaultClient vaultclient.VaultClient
+	vaultTokens map[string]vaultToken
+
 	destroy     bool
 	destroyCh   chan struct{}
 	destroyLock sync.Mutex
@@ -82,19 +91,20 @@ type allocRunnerState struct {
 
 // NewAllocRunner is used to create a new allocation context
 func NewAllocRunner(logger *log.Logger, config *config.Config, updater AllocStateUpdater,
-	alloc *structs.Allocation) *AllocRunner {
+	alloc *structs.Allocation, vaultClient vaultclient.VaultClient) *AllocRunner {
 	ar := &AllocRunner{
-		config:     config,
-		updater:    updater,
-		logger:     logger,
-		alloc:      alloc,
-		dirtyCh:    make(chan struct{}, 1),
-		tasks:      make(map[string]*TaskRunner),
-		taskStates: copyTaskStates(alloc.TaskStates),
-		restored:   make(map[string]struct{}),
-		updateCh:   make(chan *structs.Allocation, 64),
-		destroyCh:  make(chan struct{}),
-		waitCh:     make(chan struct{}),
+		config:      config,
+		updater:     updater,
+		logger:      logger,
+		alloc:       alloc,
+		dirtyCh:     make(chan struct{}, 1),
+		tasks:       make(map[string]*TaskRunner),
+		taskStates:  copyTaskStates(alloc.TaskStates),
+		restored:    make(map[string]struct{}),
+		updateCh:    make(chan *structs.Allocation, 64),
+		destroyCh:   make(chan struct{}),
+		waitCh:      make(chan struct{}),
+		vaultClient: vaultClient,
 	}
 	return ar
 }
@@ -133,6 +143,9 @@ func (r *AllocRunner) RestoreState() error {
 		return e
 	}
 
+	// Recover the Vault tokens
+	vaultErr := r.recoverVaultTokens()
+
 	// Restore the task runners
 	var mErr multierror.Error
 	for name, state := range r.taskStates {
@@ -144,6 +157,10 @@ func (r *AllocRunner) RestoreState() error {
 			task)
 		r.tasks[name] = tr
 
+		if vt, ok := r.vaultTokens[name]; ok {
+			tr.SetVaultToken(vt.token, vt.renewalCh)
+		}
+
 		// Skip tasks in terminal states.
 		if state.State == structs.TaskStateDead {
 			continue
@@ -157,6 +174,21 @@ func (r *AllocRunner) RestoreState() error {
 			go tr.Run()
 		}
 	}
+
+	// Since this is somewhat of an expected case we do not return an error but
+	// handle it gracefully.
+	if vaultErr != nil {
+		msg := fmt.Sprintf("failed to recover Vault tokens for allocation %q: %v", r.alloc.ID, vaultErr)
+		r.logger.Printf("[ERR] client: %s", msg)
+		r.setStatus(structs.AllocClientStatusFailed, msg)
+
+		// Destroy the task runners and set the error
+		r.destroyTaskRunners(structs.NewTaskEvent(structs.TaskVaultRenewalFailed).SetVaultRenewalError(vaultErr))
+
+		// Handle cleanup
+		go r.handleDestroy()
+	}
+
 	return mErr.ErrorOrNil()
 }
@@ -333,17 +365,26 @@ func (r *AllocRunner) setTaskState(taskName, state string, event *structs.TaskEv
 	taskState.State = state
 	r.appendTaskEvent(taskState, event)
 
-	// If the task failed, we should kill all the other tasks in the task group.
-	if state == structs.TaskStateDead && taskState.Failed() {
-		var destroyingTasks []string
-		for task, tr := range r.tasks {
-			if task != taskName {
-				destroyingTasks = append(destroyingTasks, task)
-				tr.Destroy(structs.NewTaskEvent(structs.TaskSiblingFailed).SetFailedSibling(taskName))
+	if state == structs.TaskStateDead {
+		// If the task has a Vault token, stop renewing it
+		if vt, ok := r.vaultTokens[taskName]; ok {
+			if err := r.vaultClient.StopRenewToken(vt.token); err != nil {
+				r.logger.Printf("[ERR] client: stopping token renewal for task %q failed: %v", taskName, err)
 			}
 		}
-		if len(destroyingTasks) > 0 {
-			r.logger.Printf("[DEBUG] client: task %q failed, destroying other tasks in task group: %v", taskName, destroyingTasks)
+
+		// If the task failed, we should kill all the other tasks in the task group.
+		if taskState.Failed() {
+			var destroyingTasks []string
+			for task, tr := range r.tasks {
+				if task != taskName {
+					destroyingTasks = append(destroyingTasks, task)
+					tr.Destroy(structs.NewTaskEvent(structs.TaskSiblingFailed).SetFailedSibling(taskName))
+				}
+			}
+			if len(destroyingTasks) > 0 {
+				r.logger.Printf("[DEBUG] client: task %q failed, destroying other tasks in task group: %v", taskName, destroyingTasks)
+			}
 		}
 	}
@@ -408,6 +449,15 @@ func (r *AllocRunner) Run() {
 		return
 	}
 
+	// Request Vault tokens for the tasks that require them
+	err := r.deriveVaultTokens()
+	if err != nil {
+		msg := fmt.Sprintf("failed to derive Vault token for allocation %q: %v", r.alloc.ID, err)
+		r.logger.Printf("[ERR] client: %s", msg)
+		r.setStatus(structs.AllocClientStatusFailed, msg)
+		return
+	}
+
 	// Start the task runners
 	r.logger.Printf("[DEBUG] client: starting task runners for alloc '%s'", r.alloc.ID)
 	r.taskLock.Lock()
@@ -416,10 +466,15 @@ func (r *AllocRunner) Run() {
 			continue
 		}
 
-		tr := NewTaskRunner(r.logger, r.config, r.setTaskState, r.ctx, r.Alloc(),
-			task.Copy())
+		tr := NewTaskRunner(r.logger, r.config, r.setTaskState, r.ctx, r.Alloc(), task.Copy())
 		r.tasks[task.Name] = tr
 		tr.MarkReceived()
+
+		// If the task has a vault token set it before running
+		if vt, ok := r.vaultTokens[task.Name]; ok {
+			tr.SetVaultToken(vt.token, vt.renewalCh)
+		}
+
 		go tr.Run()
 	}
 	r.taskLock.Unlock()
@@ -467,10 +522,24 @@ OUTER:
 		}
 	}
 
+	// Kill the task runners
+	r.destroyTaskRunners(taskDestroyEvent)
+
+	// Stop watching the shared allocation directory
+	r.ctx.AllocDir.StopDiskWatcher()
+
+	// Block until we should destroy the state of the alloc
+	r.handleDestroy()
+	r.logger.Printf("[DEBUG] client: terminating runner for alloc '%s'", r.alloc.ID)
+}
+
+// destroyTaskRunners destroys the task runners, waits for them to terminate and
+// then saves state.
+func (r *AllocRunner) destroyTaskRunners(destroyEvent *structs.TaskEvent) {
 	// Destroy each sub-task
 	runners := r.getTaskRunners()
 	for _, tr := range runners {
-		tr.Destroy(taskDestroyEvent)
+		tr.Destroy(destroyEvent)
 	}
 
 	// Wait for termination of the task runners
@@ -480,13 +549,149 @@ OUTER:
 
 	// Final state sync
 	r.syncStatus()
+}
 
-	// Stop watching the shared allocation directory
-	r.ctx.AllocDir.StopDiskWatcher()
+// vaultToken acts as a tuple of the token and renewal channel
+type vaultToken struct {
+	token     string
+	renewalCh <-chan error
+}
 
-	// Block until we should destroy the state of the alloc
-	r.handleDestroy()
-	r.logger.Printf("[DEBUG] client: terminating runner for alloc '%s'", r.alloc.ID)
+// deriveVaultTokens derives the required vault tokens and returns a map of the
+// tasks to their respective vault token and renewal channel. This must be
+// called after the allocation directory is created as the vault tokens are
+// written to disk.
+func (r *AllocRunner) deriveVaultTokens() error {
+	required, err := r.tasksRequiringVaultTokens()
+	if err != nil {
+		return err
+	}
+
+	if len(required) == 0 {
+		return nil
+	}
+
+	if r.vaultTokens == nil {
+		r.vaultTokens = make(map[string]vaultToken, len(required))
+	}
+
+	// Get the tokens
+	tokens, err := r.vaultClient.DeriveToken(r.Alloc(), required)
+	if err != nil {
+		return fmt.Errorf("failed to derive Vault tokens: %v", err)
+	}
+
+	// Persist the tokens to the appropriate secret directories
+	adir := r.ctx.AllocDir
+	for task, token := range tokens {
+		// Has been recovered
+		if _, ok := r.vaultTokens[task]; ok {
+			continue
+		}
+
+		secretDir, err := adir.GetSecretDir(task)
+		if err != nil {
+			return fmt.Errorf("failed to determine task %s secret dir in alloc %q: %v", task, r.alloc.ID, err)
+		}
+
+		// Write the token to the file system
+		tokenPath := filepath.Join(secretDir, vaultTokenFile)
+		if err := ioutil.WriteFile(tokenPath, []byte(token), 0777); err != nil {
+			return fmt.Errorf("failed to save Vault tokens to secret dir for task %q in alloc %q: %v", task, r.alloc.ID, err)
+		}
+
+		// Start renewing the token
+		renewCh, err := r.vaultClient.RenewToken(token, 10)
+		if err != nil {
+			var mErr multierror.Error
+			errMsg := fmt.Errorf("failed to renew Vault token for task %q in alloc %q: %v", task, r.alloc.ID, err)
+			multierror.Append(&mErr, errMsg)
+
+			// Clean up any token that we have started renewing
+			for _, token := range r.vaultTokens {
+				if err := r.vaultClient.StopRenewToken(token.token); err != nil {
+					multierror.Append(&mErr, err)
+				}
+			}
+
+			return mErr.ErrorOrNil()
+		}
+		r.vaultTokens[task] = vaultToken{token: token, renewalCh: renewCh}
+	}
+
+	return nil
+}
+
+// tasksRequiringVaultTokens returns the set of tasks that require a Vault token
+func (r *AllocRunner) tasksRequiringVaultTokens() ([]string, error) {
+	// Get the tasks
+	tg := r.alloc.Job.LookupTaskGroup(r.alloc.TaskGroup)
+	if tg == nil {
+		return nil, fmt.Errorf("Failed to lookup task group in alloc")
+	}
+
+	// Retrieve any required Vault tokens
+	var required []string
+	for _, task := range tg.Tasks {
+		if task.Vault != nil && len(task.Vault.Policies) != 0 {
+			required = append(required, task.Name)
+		}
+	}
+
+	return required, nil
+}
+
+// recoverVaultTokens reads the Vault tokens for the tasks that have Vault
+// tokens off disk. If there is an error, it is returned, otherwise token
+// renewal is started.
+func (r *AllocRunner) recoverVaultTokens() error {
+	required, err := r.tasksRequiringVaultTokens()
+	if err != nil {
+		return err
+	}
+
+	if len(required) == 0 {
+		return nil
+	}
+
+	// Read the tokens and start renewing them
+	adir := r.ctx.AllocDir
+	renewingTokens := make(map[string]vaultToken, len(required))
+	for _, task := range required {
+		secretDir, err := adir.GetSecretDir(task)
+		if err != nil {
+			return fmt.Errorf("failed to determine task %s secret dir in alloc %q: %v", task, r.alloc.ID, err)
+		}
+
+		// Read the token from the secret directory
+		tokenPath := filepath.Join(secretDir, vaultTokenFile)
+		data, err := ioutil.ReadFile(tokenPath)
+		if err != nil {
+			return fmt.Errorf("failed to read token for task %q in alloc %q: %v", task, r.alloc.ID, err)
+		}
+
+		token := string(data)
+		renewCh, err := r.vaultClient.RenewToken(token, 10)
+		if err != nil {
+			var mErr multierror.Error
+			errMsg := fmt.Errorf("failed to renew Vault token for task %q in alloc %q: %v", task, r.alloc.ID, err)
+			multierror.Append(&mErr, errMsg)
+
+			// Clean up any token that we have started renewing
+			for _, token := range renewingTokens {
+				if err := r.vaultClient.StopRenewToken(token.token); err != nil {
+					multierror.Append(&mErr, err)
+				}
+			}
+
+			return mErr.ErrorOrNil()
+		}
+
+		renewingTokens[task] = vaultToken{token: token, renewalCh: renewCh}
+	}
+
+	r.vaultTokens = renewingTokens
+	return nil
 }
 
 // checkResources monitors and enforces alloc resource usage. It returns an
diff --git a/client/alloc_runner_test.go b/client/alloc_runner_test.go
index cbc0c12aa40..e848596ec4d 100644
--- a/client/alloc_runner_test.go
+++ b/client/alloc_runner_test.go
@@ -3,7 +3,9 @@ package client
 import (
 	"bufio"
 	"fmt"
+	"io/ioutil"
 	"os"
+	"path/filepath"
 	"testing"
 	"time"
 
@@ -13,6 +15,7 @@ import (
 	"github.com/hashicorp/nomad/client/config"
 	ctestutil "github.com/hashicorp/nomad/client/testutil"
+	"github.com/hashicorp/nomad/client/vaultclient"
 )
 
 type MockAllocStateUpdater struct {
@@ -35,7 +38,8 @@ func testAllocRunnerFromAlloc(alloc *structs.Allocation, restarts bool) (*MockAl
 		*alloc.Job.LookupTaskGroup(alloc.TaskGroup).RestartPolicy = structs.RestartPolicy{Attempts: 0}
 		alloc.Job.Type = structs.JobTypeBatch
 	}
-	ar := NewAllocRunner(logger, conf, upd.Update, alloc)
+	vclient := vaultclient.NewMockVaultClient()
+	ar := NewAllocRunner(logger, conf, upd.Update, alloc, vclient)
 	return upd, ar
 }
 
@@ -324,7 +328,7 @@ func TestAllocRunner_Destroy(t *testing.T) {
 
 	// Begin the tear down
 	go func() {
-		time.Sleep(100 * time.Millisecond)
+		time.Sleep(1 * time.Second)
 		ar.Destroy()
 	}()
 
@@ -390,13 +394,15 @@ func TestAllocRunner_Update(t *testing.T) {
 }
 
 func TestAllocRunner_SaveRestoreState(t *testing.T) {
-	ctestutil.ExecCompatible(t)
-	upd, ar := testAllocRunner(false)
+	alloc := mock.Alloc()
+	task := alloc.Job.TaskGroups[0].Tasks[0]
+	task.Driver = "mock_driver"
+	task.Config = map[string]interface{}{
+		"exit_code": "0",
+		"run_for":   "10s",
+	}
 
-	// Ensure task takes some time
-	task := ar.alloc.Job.TaskGroups[0].Tasks[0]
-	task.Config["command"] = "/bin/sleep"
-	task.Config["args"] = []string{"10"}
+	upd, ar := testAllocRunnerFromAlloc(alloc, false)
 	go ar.Run()
 
 	// Snapshot state
@@ -413,28 +419,43 @@ func TestAllocRunner_SaveRestoreState(t *testing.T) {
 
 	// Create a new alloc runner
 	ar2 := NewAllocRunner(ar.logger, ar.config, upd.Update,
-		&structs.Allocation{ID: ar.alloc.ID})
+		&structs.Allocation{ID: ar.alloc.ID}, ar.vaultClient)
 	err = ar2.RestoreState()
 	if err != nil {
 		t.Fatalf("err: %v", err)
 	}
 	go ar2.Run()
 
+	testutil.WaitForResult(func() (bool, error) {
+		if len(ar2.tasks) != 1 {
+			return false, fmt.Errorf("Incorrect number of tasks")
+		}
+
+		if upd.Count == 0 {
+			return false, nil
+		}
+
+		last := upd.Allocs[upd.Count-1]
+		return last.ClientStatus == structs.AllocClientStatusRunning, nil
+	}, func(err error) {
+		t.Fatalf("err: %v %#v %#v", err, upd.Allocs[0], ar.alloc.TaskStates)
+	})
+
 	// Destroy and wait
 	ar2.Destroy()
 	start := time.Now()
 
 	testutil.WaitForResult(func() (bool, error) {
-		if upd.Count == 0 {
-			return false, nil
+		alloc := ar2.Alloc()
+		if alloc.ClientStatus != structs.AllocClientStatusComplete {
+			return false, fmt.Errorf("Bad client status; got %v; want %v", alloc.ClientStatus, structs.AllocClientStatusComplete)
 		}
-		last := upd.Allocs[upd.Count-1]
-		return last.ClientStatus != structs.AllocClientStatusPending, nil
+		return true, nil
 	}, func(err error) {
 		t.Fatalf("err: %v %#v %#v", err, upd.Allocs[0], ar.alloc.TaskStates)
 	})
 
-	if time.Since(start) > time.Duration(testutil.TestMultiplier()*15)*time.Second {
+	if time.Since(start) > time.Duration(testutil.TestMultiplier()*5)*time.Second {
 		t.Fatalf("took too long to terminate")
 	}
 }
@@ -486,7 +507,7 @@ func TestAllocRunner_SaveRestoreState_TerminalAlloc(t *testing.T) {
 
 	// Create a new alloc runner
 	ar2 := NewAllocRunner(ar.logger, ar.config, upd.Update,
-		&structs.Allocation{ID: ar.alloc.ID})
+		&structs.Allocation{ID: ar.alloc.ID}, ar.vaultClient)
 	ar2.logger = prefixedTestLogger("ar2: ")
 	err = ar2.RestoreState()
 	if err != nil {
@@ -577,7 +598,10 @@ func TestAllocRunner_TaskFailed_KillTG(t *testing.T) {
 		if state1.State != structs.TaskStateDead {
 			return false, fmt.Errorf("got state %v; want %v", state1.State, structs.TaskStateDead)
 		}
-		if lastE := state1.Events[len(state1.Events)-1]; lastE.Type != structs.TaskSiblingFailed {
+		if len(state1.Events) < 3 {
+			return false, fmt.Errorf("Unexpected number of events")
+		}
+		if lastE := state1.Events[len(state1.Events)-3]; lastE.Type != structs.TaskSiblingFailed {
 			return false, fmt.Errorf("got last event %v; want %v", lastE.Type, structs.TaskSiblingFailed)
 		}
 
@@ -595,3 +619,245 @@ func TestAllocRunner_TaskFailed_KillTG(t *testing.T) {
 		t.Fatalf("err: %v", err)
 	})
 }
+
+func TestAllocRunner_SimpleRun_VaultToken(t *testing.T) {
+	alloc := mock.Alloc()
+	task := alloc.Job.TaskGroups[0].Tasks[0]
+	task.Driver = "mock_driver"
+	task.Config = map[string]interface{}{"exit_code": "0"}
+	task.Vault = &structs.Vault{
+		Policies: []string{"default"},
+	}
+
+	upd, ar := testAllocRunnerFromAlloc(alloc, false)
+	go ar.Run()
+	defer ar.Destroy()
+
+	testutil.WaitForResult(func() (bool, error) {
+		if upd.Count == 0 {
+			return false, fmt.Errorf("No updates")
+		}
+		last := upd.Allocs[upd.Count-1]
+		if last.ClientStatus != structs.AllocClientStatusComplete {
+			return false, fmt.Errorf("got status %v; want %v", last.ClientStatus, structs.AllocClientStatusComplete)
+		}
+		return true, nil
+	}, func(err error) {
+		t.Fatalf("err: %v", err)
+	})
+
+	tr, ok := ar.tasks[task.Name]
+	if !ok {
+		t.Fatalf("No task runner made")
+	}
+
+	// Check that the task runner was given the token
+	token := tr.vaultToken
+	if token == "" || tr.vaultRenewalCh == nil {
+		t.Fatalf("Vault token not set properly")
+	}
+
+	// Check that it was written to disk
+	secretDir, err := ar.ctx.AllocDir.GetSecretDir(task.Name)
+	if err != nil {
+		t.Fatalf("bad: %v", err)
+	}
+
+	tokenPath := filepath.Join(secretDir, vaultTokenFile)
+	data, err := ioutil.ReadFile(tokenPath)
+	if err != nil {
+		t.Fatalf("token not written to disk: %v", err)
+	}
+
+	if string(data) != token {
+		t.Fatalf("Bad token written to disk")
+	}
+
+	// Check that we stopped renewing the token
+	mockVC := ar.vaultClient.(*vaultclient.MockVaultClient)
+	if len(mockVC.StoppedTokens) != 1 || mockVC.StoppedTokens[0] != token {
+		t.Fatalf("We didn't stop renewing the token")
+	}
+}
+
+func TestAllocRunner_SaveRestoreState_VaultTokens_Valid(t *testing.T) {
+	alloc := mock.Alloc()
+	task := alloc.Job.TaskGroups[0].Tasks[0]
+	task.Driver = "mock_driver"
+	task.Config = map[string]interface{}{
+		"exit_code": "0",
+		"run_for":   "10s",
+	}
+	task.Vault = &structs.Vault{
+		Policies: []string{"default"},
+	}
+
+	upd, ar := testAllocRunnerFromAlloc(alloc, false)
+	go ar.Run()
+
+	// Snapshot state
+	var token string
+	testutil.WaitForResult(func() (bool, error) {
+		if len(ar.tasks) != 1 {
+			return false, fmt.Errorf("Task not started")
+		}
+
+		tr, ok := ar.tasks[task.Name]
+		if !ok {
+			return false, fmt.Errorf("Incorrect task runner")
+		}
+
+		if tr.vaultToken == "" {
+			return false, fmt.Errorf("Bad token")
+		}
+
+		token = tr.vaultToken
+		return true, nil
+	}, func(err error) {
+		t.Fatalf("task never started: %v", err)
+	})
+
+	err := ar.SaveState()
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+
+	// Create a new alloc runner
+	ar2 := NewAllocRunner(ar.logger, ar.config, upd.Update,
+		&structs.Allocation{ID: ar.alloc.ID}, ar.vaultClient)
+	err = ar2.RestoreState()
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	go ar2.Run()
+
+	testutil.WaitForResult(func() (bool, error) {
+		if len(ar2.tasks) != 1 {
+			return false, fmt.Errorf("Incorrect number of tasks")
+		}
+
+		tr, ok := ar2.tasks[task.Name]
+		if !ok {
+			return false, fmt.Errorf("Incorrect task runner")
+		}
+
+		if tr.vaultToken != token {
+			return false, fmt.Errorf("Got token %q; want %q", tr.vaultToken, token)
+		}
+
+		if upd.Count == 0 {
+			return false, nil
+		}
+
+		last := upd.Allocs[upd.Count-1]
+		return last.ClientStatus == structs.AllocClientStatusRunning, nil
+	}, func(err error) {
+		t.Fatalf("err: %v %#v %#v", err, upd.Allocs[0], ar.alloc.TaskStates)
+	})
+
+	// Destroy and wait
+	ar2.Destroy()
+	start := time.Now()
+
+	testutil.WaitForResult(func() (bool, error) {
+		alloc := ar2.Alloc()
+		if alloc.ClientStatus != structs.AllocClientStatusComplete {
+			return false, fmt.Errorf("Bad client status; got %v; want %v", alloc.ClientStatus, structs.AllocClientStatusComplete)
+		}
+		return true, nil
+	}, func(err error) {
+		t.Fatalf("err: %v %#v %#v", err, upd.Allocs[0], ar.alloc.TaskStates)
+	})
+
+	if time.Since(start) > time.Duration(testutil.TestMultiplier()*5)*time.Second {
+		t.Fatalf("took too long to terminate")
+	}
+}
+
+func TestAllocRunner_SaveRestoreState_VaultTokens_Invalid(t *testing.T) {
+	alloc := mock.Alloc()
+	task := alloc.Job.TaskGroups[0].Tasks[0]
+	task.Driver = "mock_driver"
+	task.Config = map[string]interface{}{
+		"exit_code": "0",
+		"run_for":   "10s",
+	}
+	task.Vault = &structs.Vault{
+		Policies: []string{"default"},
+	}
+
+	upd, ar := testAllocRunnerFromAlloc(alloc, false)
+	go ar.Run()
+
+	// Snapshot state
+	var token string
+	testutil.WaitForResult(func() (bool, error) {
+		if len(ar.tasks) != 1 {
+			return false, fmt.Errorf("Task not started")
+		}
+
+		tr, ok := ar.tasks[task.Name]
+		if !ok {
+			return false, fmt.Errorf("Incorrect task runner")
+		}
+
+		if tr.vaultToken == "" {
+			return false, fmt.Errorf("Bad token")
+		}
+
+		token = tr.vaultToken
+		return true, nil
+	}, func(err error) {
+		t.Fatalf("task never started: %v", err)
+	})
+
+	err := ar.SaveState()
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+
+	// Create a new alloc runner
+	ar2 := NewAllocRunner(ar.logger, ar.config, upd.Update,
+		&structs.Allocation{ID: ar.alloc.ID}, ar.vaultClient)
+
+	// Invalidate the token
+	mockVC := ar2.vaultClient.(*vaultclient.MockVaultClient)
+	renewErr := fmt.Errorf("Test disallowing renewal")
+	mockVC.SetRenewTokenError(token, renewErr)
+
+	// Restore and run
+	err = ar2.RestoreState()
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	go ar2.Run()
+
+	testutil.WaitForResult(func() (bool, error) {
+		if upd.Count == 0 {
+			return false, nil
+		}
+
+		last := upd.Allocs[upd.Count-1]
+		return last.ClientStatus == structs.AllocClientStatusFailed, nil
+	}, func(err error) {
+		t.Fatalf("err: %v %#v %#v", err, upd.Allocs[0], ar.alloc.TaskStates)
+	})
+
+	// Destroy and wait
+	ar2.Destroy()
+	start := time.Now()
+
+	testutil.WaitForResult(func() (bool, error) {
+		alloc := ar2.Alloc()
+		if alloc.ClientStatus != structs.AllocClientStatusFailed {
+			return false, fmt.Errorf("Bad client status; got %v; want %v", alloc.ClientStatus, structs.AllocClientStatusFailed)
+		}
+		return true, nil
+	}, func(err error) {
+		t.Fatalf("err: %v %#v %#v", err, upd.Allocs[0], ar.alloc.TaskStates)
+	})
+
+	if time.Since(start) > time.Duration(testutil.TestMultiplier()*5)*time.Second {
+		t.Fatalf("took too long to terminate")
+	}
+}
diff --git a/client/allocdir/alloc_dir.go b/client/allocdir/alloc_dir.go
index 312bb5d797e..981e6db9bd5 100644
--- a/client/allocdir/alloc_dir.go
+++ b/client/allocdir/alloc_dir.go
@@ -552,3 +552,11 @@ func (d *AllocDir) syncDiskUsage() error {
 	d.setSize(size)
 	return err
 }
+
+func (d *AllocDir) GetSecretDir(task string) (string, error) {
+	if t, ok := d.TaskDirs[task]; !ok {
+		return "", fmt.Errorf("Allocation directory doesn't contain task %q", task)
+	} else {
+		return filepath.Join(t, TaskSecrets), nil
+	}
+}
diff --git a/client/client.go b/client/client.go
index 76b5748f8b6..58b1e8e3426 100644
--- a/client/client.go
+++ b/client/client.go
@@ -150,7 +150,7 @@ type Client struct {
 	shutdownCh   chan struct{}
 	shutdownLock sync.Mutex
 
-	// client to interact with vault for token and secret renewals
+	// vaultClient is used to interact with Vault for token and secret renewals
 	vaultClient vaultclient.VaultClient
 }
 
@@ -208,11 +208,6 @@ func NewClient(cfg *config.Config, consulSyncer *consul.Syncer, logger *log.Logg
 	}
 	c.configLock.RUnlock()
 
-	// Restore the state
-	if err := c.restoreState(); err != nil {
-		return nil, fmt.Errorf("failed to restore state: %v", err)
-	}
-
 	// Setup the Consul syncer
 	if err := c.setupConsulSyncer(); err != nil {
 		return nil, fmt.Errorf("failed to create client Consul syncer: %v", err)
@@ -223,6 +218,11 @@ func NewClient(cfg *config.Config, consulSyncer *consul.Syncer, logger *log.Logg
 		return nil, fmt.Errorf("failed to setup vault client: %v", err)
 	}
 
+	// Restore the state
+	if err := c.restoreState(); err != nil {
+		return nil, fmt.Errorf("failed to restore state: %v", err)
+	}
+
 	// Register and then start heartbeating to the servers.
 	go c.registerAndHeartbeat()
 
@@ -248,11 +248,6 @@ func NewClient(cfg *config.Config, consulSyncer *consul.Syncer, logger *log.Logg
 	// populated by periodically polling Consul, if available.
 	go c.rpcProxy.Run()
 
-	// Start renewing tokens and secrets
-	if c.vaultClient != nil {
-		c.vaultClient.Start()
-	}
-
 	return c, nil
 }
 
@@ -469,7 +464,7 @@ func (c *Client) restoreState() error {
 		id := entry.Name()
 		alloc := &structs.Allocation{ID: id}
 		c.configLock.RLock()
-		ar := NewAllocRunner(c.logger, c.configCopy, c.updateAllocStatus, alloc)
+		ar := NewAllocRunner(c.logger, c.configCopy, c.updateAllocStatus, alloc, c.vaultClient)
 		c.configLock.RUnlock()
 		c.allocLock.Lock()
 		c.allocs[id] = ar
@@ -1285,7 +1280,7 @@ func (c *Client) updateAlloc(exist, update *structs.Allocation) error {
 // addAlloc is invoked when we should add an allocation
 func (c *Client) addAlloc(alloc *structs.Allocation) error {
 	c.configLock.RLock()
-	ar := NewAllocRunner(c.logger, c.configCopy, c.updateAllocStatus, alloc)
+	ar := NewAllocRunner(c.logger, c.configCopy, c.updateAllocStatus, alloc, c.vaultClient)
 	c.configLock.RUnlock()
 	go ar.Run()
 
@@ -1299,14 +1294,6 @@ func (c *Client) addAlloc(alloc *structs.Allocation) error {
 // setupVaultClient creates an object to periodically renew tokens and secrets
 // with vault.
 func (c *Client) setupVaultClient() error {
-	if c.config.VaultConfig == nil {
-		return fmt.Errorf("nil vault config")
-	}
-
-	if !c.config.VaultConfig.Enabled {
-		return nil
-	}
-
 	var err error
 	if c.vaultClient, err =
 		vaultclient.NewVaultClient(c.config.VaultConfig, c.logger, c.deriveToken); err != nil {
@@ -1318,6 +1305,9 @@ func (c *Client) setupVaultClient() error {
 		return fmt.Errorf("failed to create vault client")
 	}
 
+	// Start renewing tokens and secrets
+	c.vaultClient.Start()
+
 	return nil
 }
diff --git a/client/driver/docker.go b/client/driver/docker.go
index c2fdc92c2a5..172117d3e3a 100644
--- a/client/driver/docker.go
+++ b/client/driver/docker.go
@@ -348,17 +348,24 @@ func (d *DockerDriver) containerBinds(alloc *allocdir.AllocDir, task *structs.Ta
 	if !ok {
 		return nil, fmt.Errorf("Failed to find task local directory: %v", task.Name)
 	}
+	secret, err := alloc.GetSecretDir(task.Name)
+	if err != nil {
+		return nil, err
+	}
 
 	allocDirBind := fmt.Sprintf("%s:%s", shared, allocdir.SharedAllocContainerPath)
 	taskLocalBind := fmt.Sprintf("%s:%s", local, allocdir.TaskLocalContainerPath)
+	secretDirBind := fmt.Sprintf("%s:%s", secret, allocdir.TaskSecretsContainerPath)
 
 	if selinuxLabel := d.config.Read("docker.volumes.selinuxlabel"); selinuxLabel != "" {
 		allocDirBind = fmt.Sprintf("%s:%s", allocDirBind, selinuxLabel)
 		taskLocalBind = fmt.Sprintf("%s:%s", taskLocalBind, selinuxLabel)
+		secretDirBind = fmt.Sprintf("%s:%s", secretDirBind, selinuxLabel)
 	}
 
 	return []string{
 		allocDirBind,
 		taskLocalBind,
+		secretDirBind,
 	}, nil
 }
diff --git a/client/driver/driver.go b/client/driver/driver.go
index 60adf9c7d77..ac152a50801 100644
--- a/client/driver/driver.go
+++ b/client/driver/driver.go
@@ -135,7 +135,7 @@ func NewExecContext(alloc *allocdir.AllocDir, allocID string) *ExecContext {
 // GetTaskEnv converts the alloc dir, the node, task and alloc into a
 // TaskEnvironment.
 func GetTaskEnv(allocDir *allocdir.AllocDir, node *structs.Node,
-	task *structs.Task, alloc *structs.Allocation) (*env.TaskEnvironment, error) {
+	task *structs.Task, alloc *structs.Allocation, vaultToken string) (*env.TaskEnvironment, error) {
 
 	tg := alloc.Job.LookupTaskGroup(alloc.TaskGroup)
 	env := env.NewTaskEnvironment(node).
@@ -166,6 +166,9 @@ func GetTaskEnv(allocDir *allocdir.AllocDir, node *structs.Node,
 		env.SetAlloc(alloc)
 	}
 
+	// TODO: make this conditional on the task's vault block allowing it
+	env.SetVaultToken(vaultToken, true)
+
 	return env.Build(), nil
 }
diff --git a/client/driver/env/env.go b/client/driver/env/env.go
index e2ff660b89f..bd74e79fc25 100644
--- a/client/driver/env/env.go
+++ b/client/driver/env/env.go
@@ -60,6 +60,9 @@ const (
 
 	// MetaPrefix is the prefix for passing task meta data.
 	MetaPrefix = "NOMAD_META_"
+
+	// VaultToken is the environment variable for passing the Vault token
+	VaultToken = "VAULT_TOKEN"
 )
 
 // The node values that can be interpreted.
@@ -77,22 +80,24 @@ const (
 // TaskEnvironment is used to expose information to a task via environment
 // variables and provide interpolation of Nomad variables.
 type TaskEnvironment struct {
-	Env           map[string]string
-	TaskMeta      map[string]string
-	TaskGroupMeta map[string]string
-	JobMeta       map[string]string
-	AllocDir      string
-	TaskDir       string
-	SecretDir     string
-	CpuLimit      int
-	MemLimit      int
-	TaskName      string
-	AllocIndex    int
-	AllocId       string
-	AllocName     string
-	Node          *structs.Node
-	Networks      []*structs.NetworkResource
-	PortMap       map[string]int
+	Env              map[string]string
+	TaskMeta         map[string]string
+	TaskGroupMeta    map[string]string
+	JobMeta          map[string]string
+	AllocDir         string
+	TaskDir          string
+	SecretDir        string
+	CpuLimit         int
+	MemLimit         int
+	TaskName         string
+	AllocIndex       int
+	AllocId          string
+	AllocName        string
+	Node             *structs.Node
+	Networks         []*structs.NetworkResource
+	PortMap          map[string]int
+	VaultToken       string
+	InjectVaultToken bool
 
 	// taskEnv is the variables that will be set in the tasks environment
 	TaskEnv map[string]string
@@ -203,6 +208,11 @@ func (t *TaskEnvironment) Build() *TaskEnvironment {
 		}
 	}
 
+	// Build the Vault Token
+	if t.InjectVaultToken && t.VaultToken != "" {
+		t.TaskEnv[VaultToken] = t.VaultToken
+	}
+
 	// Interpret the environment variables
 	interpreted := make(map[string]string, len(t.Env))
 	for k, v := range t.Env {
@@ -446,3 +456,15 @@ func (t *TaskEnvironment) ClearTaskName() *TaskEnvironment {
 	t.TaskName = ""
 	return t
 }
+
+func (t *TaskEnvironment) SetVaultToken(token string, inject bool) *TaskEnvironment {
+	t.VaultToken = token
+	t.InjectVaultToken = inject
+	return t
+}
+
+func (t *TaskEnvironment) ClearVaultToken() *TaskEnvironment {
+	t.VaultToken = ""
+	t.InjectVaultToken = false
+	return t
+}
diff --git a/client/driver/env/env_test.go b/client/driver/env/env_test.go
index 19c980ee137..88c5986e095 100644
--- a/client/driver/env/env_test.go
+++ b/client/driver/env/env_test.go
@@ -163,6 +163,23 @@ func TestEnvironment_AsList(t *testing.T) {
 	}
 }
 
+func TestEnvironment_VaultToken(t *testing.T) {
+	n := mock.Node()
+	env := NewTaskEnvironment(n).SetVaultToken("123", false).Build()
+
+	act := env.EnvList()
+	if len(act) != 0 {
+		t.Fatalf("Unexpected environment variables: %v", act)
+	}
+
+	env = env.SetVaultToken("123", true).Build()
+	act = env.EnvList()
+	exp := []string{"VAULT_TOKEN=123"}
+	if !reflect.DeepEqual(act, exp) {
+		t.Fatalf("env.List() returned %v; want %v", act, exp)
+	}
+}
+
 func TestEnvironment_ClearEnvvars(t *testing.T) {
 	n := mock.Node()
 	env := NewTaskEnvironment(n).
diff --git a/client/driver/mock_driver.go b/client/driver/mock_driver.go
index 4e46a399781..5ff2219bebb 100644
--- a/client/driver/mock_driver.go
+++ b/client/driver/mock_driver.go
@@ -3,7 +3,9 @@ package driver
 
 import (
+	"encoding/json"
 	"errors"
+	"fmt"
 	"log"
 	"time"
@@ -90,19 +92,11 @@ func (m *MockDriver) Start(ctx *ExecContext, task *structs.Task) (DriverHandle,
 	return &h, nil
 }
 
-// TODO implement Open when we need it.
-// Open re-connects the driver to the running task
-func (m *MockDriver) Open(ctx *ExecContext, handleID string) (DriverHandle, error) {
-	return nil, nil
-}
-
-// TODO implement Open when we need it.
 // Validate validates the mock driver configuration
 func (m *MockDriver) Validate(map[string]interface{}) error {
 	return nil
 }
 
-// TODO implement Open when we need it.
 // Fingerprint fingerprints a node and returns if MockDriver is enabled
 func (m *MockDriver) Fingerprint(cfg *config.Config, node *structs.Node) (bool, error) {
 	node.Attributes["driver.mock_driver"] = "1"
@@ -123,12 +117,58 @@ type mockDriverHandle struct {
 	doneCh chan struct{}
 }
 
-// TODO Implement when we need it.
+type mockDriverID struct {
+	TaskName    string
+	RunFor      time.Duration
+	KillAfter   time.Duration
+	KillTimeout time.Duration
+	ExitCode    int
+	ExitSignal  int
+	ExitErr     error
+}
+
 func (h *mockDriverHandle) ID() string {
-	return ""
+	id := mockDriverID{
+		TaskName:    h.taskName,
+		RunFor:      h.runFor,
+		KillAfter:   h.killAfter,
+		KillTimeout: h.killTimeout,
+		ExitCode:    h.exitCode,
+		ExitSignal:  h.exitSignal,
+		ExitErr:     h.exitErr,
+	}
+
+	data, err := json.Marshal(id)
+	if err != nil {
+		h.logger.Printf("[ERR] driver.mock_driver: failed to marshal ID to JSON: %s", err)
+	}
+	return string(data)
+}
+
+// Open re-connects the driver to the running task
+func (m *MockDriver) Open(ctx *ExecContext, handleID string) (DriverHandle, error) {
+	id := &mockDriverID{}
+	if err := json.Unmarshal([]byte(handleID), id); err != nil {
+		return nil, fmt.Errorf("Failed to parse handle '%s': %v", handleID, err)
+	}
+
+	h := mockDriverHandle{
+		taskName:    id.TaskName,
+		runFor:      id.RunFor,
+		killAfter:   id.KillAfter,
+		killTimeout: id.KillTimeout,
+		exitCode:    id.ExitCode,
+		exitSignal:  id.ExitSignal,
+		exitErr:     id.ExitErr,
+		logger:      m.logger,
+		doneCh:      make(chan struct{}),
+		waitCh:      make(chan *dstructs.WaitResult, 1),
+	}
+
+	go h.run()
+	return &h, nil
 }
 
-// TODO Implement when we need it.
 func (h *mockDriverHandle) WaitCh() chan *dstructs.WaitResult {
 	return h.waitCh
 }
diff --git a/client/task_runner.go b/client/task_runner.go
index 1f7d86051fc..3f3412d2920 100644
--- a/client/task_runner.go
+++ b/client/task_runner.go
@@ -64,6 +64,11 @@ type TaskRunner struct {
 	// downloaded
 	artifactsDownloaded bool
 
+	// vaultToken and vaultRenewalCh are optionally set if the task requires
+	// Vault tokens
+	vaultToken     string
+	vaultRenewalCh <-chan error
+
 	destroy     bool
 	destroyCh   chan struct{}
 	destroyLock sync.Mutex
@@ -117,6 +122,13 @@ func NewTaskRunner(logger *log.Logger, config *config.Config,
 	return tc
 }
 
+// SetVaultToken is used to set the Vault token and renewal channel for the task
+// runner
+func (r *TaskRunner) SetVaultToken(token string, renewalCh <-chan error) {
+	r.vaultToken = token
+	r.vaultRenewalCh = renewalCh
+}
+
 // MarkReceived marks the task as received.
 func (r *TaskRunner) MarkReceived() {
 	r.updater(r.task.Name, structs.TaskStatePending, structs.NewTaskEvent(structs.TaskReceived))
 }
@@ -224,7 +236,7 @@ func (r *TaskRunner) setState(state string, event *structs.TaskEvent) {
 // setTaskEnv sets the task environment. It returns an error if it could not be
 // created.
 func (r *TaskRunner) setTaskEnv() error {
-	taskEnv, err := driver.GetTaskEnv(r.ctx.AllocDir, r.config.Node, r.task.Copy(), r.alloc)
+	taskEnv, err := driver.GetTaskEnv(r.ctx.AllocDir, r.config.Node, r.task.Copy(), r.alloc, r.vaultToken)
 	if err != nil {
 		return err
 	}
@@ -390,7 +402,23 @@ func (r *TaskRunner) run() {
 				if err := r.handleUpdate(update); err != nil {
 					r.logger.Printf("[ERR] client: update to task %q failed: %v", r.task.Name, err)
 				}
+			case err := <-r.vaultRenewalCh:
+				if err == nil {
+					// Only handle once.
+					continue
+				}
+
+				// This is a fatal error as the task is not valid if it
+				// requested a Vault token and the token has now expired.
+				r.logger.Printf("[WARN] client: vault token for task %q not renewed: %v", r.task.Name, err)
+				r.Destroy(structs.NewTaskEvent(structs.TaskVaultRenewalFailed).SetVaultRenewalError(err))
+
 			case <-r.destroyCh:
+				// Store the task event that provides context on the task destroy.
+				if r.destroyEvent.Type != structs.TaskKilled {
+					r.setState(structs.TaskStateRunning, r.destroyEvent)
+				}
+
 				// Mark that we received the kill event
 				timeout := driver.GetKillTimeout(r.task.KillTimeout, r.config.MaxKillTimeout)
 				r.setState(structs.TaskStateRunning,
@@ -409,11 +437,6 @@ func (r *TaskRunner) run() {
 			// Store that the task has been destroyed and any associated error.
 			r.setState(structs.TaskStateDead, structs.NewTaskEvent(structs.TaskKilled).SetKillError(err))
 
-			// Store the task event that provides context on the task destroy.
-			if r.destroyEvent.Type != structs.TaskKilled {
-				r.setState(structs.TaskStateDead, r.destroyEvent)
-			}
-
 			r.runningLock.Lock()
 			r.running = false
 			r.runningLock.Unlock()
diff --git a/client/task_runner_test.go b/client/task_runner_test.go
index 88c2b6be41b..9b9997f238a 100644
--- a/client/task_runner_test.go
+++ b/client/task_runner_test.go
@@ -409,3 +409,65 @@ func TestTaskRunner_Validate_UserEnforcement(t *testing.T) {
 		t.Fatalf("unexpected error: %v", err)
 	}
 }
+
+func TestTaskRunner_VaultTokenRenewal(t *testing.T) {
+	alloc := mock.Alloc()
+	task := alloc.Job.TaskGroups[0].Tasks[0]
+	task.Driver = "mock_driver"
+	task.Config = map[string]interface{}{
+		"exit_code": "0",
+		"run_for":   "10s",
+	}
+	task.Vault = &structs.Vault{
+		Policies: []string{"default"},
+	}
+
+	upd, tr := testTaskRunnerFromAlloc(false, alloc)
+	tr.MarkReceived()
+	renewalCh := make(chan error, 1)
+	renewalErr := fmt.Errorf("test vault renewal error")
+	tr.SetVaultToken(structs.GenerateUUID(), renewalCh)
+	go tr.Run()
+	defer tr.Destroy(structs.NewTaskEvent(structs.TaskKilled))
+	defer tr.ctx.AllocDir.Destroy()
+
+	go func() {
+		time.Sleep(100 * time.Millisecond)
+		renewalCh <- renewalErr
+		close(renewalCh)
+	}()
+
+	select {
+	case <-tr.WaitCh():
+	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
+		t.Fatalf("timeout")
+	}
+
+	if len(upd.events) != 5 {
+		t.Fatalf("should have 5 updates: %#v", upd.events)
+	}
+
+	if upd.state != structs.TaskStateDead {
+		t.Fatalf("TaskState %v; want %v", upd.state, structs.TaskStateDead)
+	}
+
+	if upd.events[0].Type != structs.TaskReceived {
+		t.Fatalf("First Event was %v; want %v", upd.events[0].Type, structs.TaskReceived)
+	}
+
+	if upd.events[1].Type != structs.TaskStarted {
+		t.Fatalf("Second Event was %v; want %v", upd.events[1].Type, structs.TaskStarted)
+	}
+
+	if upd.events[2].Type != structs.TaskVaultRenewalFailed {
+		t.Fatalf("Third Event was %v; want %v", upd.events[2].Type, structs.TaskVaultRenewalFailed)
+	}
+
+	if upd.events[3].Type != structs.TaskKilling {
+		t.Fatalf("Fourth Event was %v; want %v", upd.events[3].Type, structs.TaskKilling)
+	}
+
+	if upd.events[4].Type != structs.TaskKilled {
+		t.Fatalf("Fifth Event was %v; want %v", upd.events[4].Type, structs.TaskKilled)
+	}
+}
diff --git a/client/vaultclient/vaultclient.go b/client/vaultclient/vaultclient.go
index e8ee7971676..8f24b8273a5 100644
--- a/client/vaultclient/vaultclient.go
+++ b/client/vaultclient/vaultclient.go
@@ -38,7 +38,7 @@ type VaultClient interface {
 
 	// RenewToken renews a token with the given increment and adds it to
 	// the min-heap for periodic renewal.
-	RenewToken(string, int) <-chan error
+	RenewToken(string, int) (<-chan error, error)
 
 	// StopRenewToken removes the token from the min-heap, stopping its
 	// renewal.
@@ -46,7 +46,7 @@ type VaultClient interface {
 
 	// RenewLease renews a vault secret's lease and adds the lease
 	// identifier to the min-heap for periodic renewal.
-	RenewLease(string, int) <-chan error
+	RenewLease(string, int) (<-chan error, error)
 
 	// StopRenewLease removes a secret's lease ID from the min-heap,
 	// stopping its renewal.
@@ -65,10 +65,6 @@ type vaultClient struct {
 	// running indicates if the renewal loop is active or not
 	running bool
 
-	// connEstablished marks whether the connection to vault was successful
-	// or not
-	connEstablished bool
-
 	// tokenData is the data of the passed VaultClient token
 	token *tokenData
 
@@ -145,10 +141,6 @@ func NewVaultClient(config *config.VaultConfig, logger *log.Logger, tokenDeriver
 		return nil, fmt.Errorf("nil vault config")
 	}
 
-	if !config.Enabled {
-		return nil, nil
-	}
-
 	if logger == nil {
 		return nil, fmt.Errorf("nil logger")
 	}
@@ -157,9 +149,14 @@ func NewVaultClient(config *config.VaultConfig, logger *log.Logger, tokenDeriver
 		config: config,
 		stopCh: make(chan struct{}),
 		// Update channel should be a buffered channel
-		updateCh: make(chan struct{}, 1),
-		heap:     newVaultClientHeap(),
-		logger:   logger,
+		updateCh:     make(chan struct{}, 1),
+		heap:         newVaultClientHeap(),
+		logger:       logger,
+		tokenDeriver: tokenDeriver,
+	}
+
+	if !config.Enabled {
+		return c, nil
 	}
 
 	// Get the Vault API configuration
@@ -207,52 +204,11 @@ func (c *vaultClient) Start() {
 		return
 	}
 
-	c.logger.Printf("[DEBUG] client.vault: establishing connection to vault")
-	go c.establishConnection()
-}
-
-// ConnectionEstablished indicates whether VaultClient successfully established
-// connection to vault or not
-func (c *vaultClient) ConnectionEstablished() bool {
-	c.lock.RLock()
-	defer c.lock.RUnlock()
-	return c.connEstablished
-}
-
-// establishConnection is used to make first contact with Vault. This should be
-// called in a go-routine since the connection is retried till the Vault Client
-// is stopped or the connection is successfully made at which point the renew
-// loop is started.
-func (c *vaultClient) establishConnection() {
-	// Create the retry timer and set initial duration to zero so it fires
-	// immediately
-	retryTimer := time.NewTimer(0)
-
-OUTER:
-	for {
-		select {
-		case <-c.stopCh:
-			return
-		case <-retryTimer.C:
-			// Ensure the API is reachable
-			if _, err := c.client.Sys().InitStatus(); err != nil {
-				c.logger.Printf("[WARN] client.vault: failed to contact Vault API. Retrying in %v: %v",
-					c.config.ConnectionRetryIntv, err)
-				retryTimer.Reset(c.config.ConnectionRetryIntv)
-				continue OUTER
-			}
-
-			break OUTER
-		}
-	}
-
 	c.lock.Lock()
-	c.connEstablished = true
+	c.running = true
 	c.lock.Unlock()
 
-	// Begin the renewal loop
 	go c.run()
-	c.logger.Printf("[DEBUG] client.vault: started")
 }
 
 // Stops the renewal loop of vault client
@@ -273,6 +229,9 @@ func (c *vaultClient) Stop() {
 // The return value is a map containing all the unwrapped tokens indexed by the
 // task name.
 func (c *vaultClient) DeriveToken(alloc *structs.Allocation, taskNames []string) (map[string]string, error) {
+	if !c.config.Enabled {
+		return nil, fmt.Errorf("vault client not enabled")
+	}
 	if !c.running {
 		return nil, fmt.Errorf("vault client is not running")
 	}
@@ -283,6 +242,9 @@ func (c *vaultClient) DeriveToken(alloc *structs.Allocation, taskNames []string)
 // GetConsulACL creates a vault API client and reads from vault a consul ACL
 // token used by the task.
 func (c *vaultClient) GetConsulACL(token, path string) (*vaultapi.Secret, error) {
+	if !c.config.Enabled {
+		return nil, fmt.Errorf("vault client not enabled")
+	}
 	if token == "" {
 		return nil, fmt.Errorf("missing token")
 	}
@@ -290,10 +252,6 @@ func (c *vaultClient) GetConsulACL(token, path string) (*vaultapi.Secret, error)
 		return nil, fmt.Errorf("missing consul ACL token vault path")
 	}
 
-	if !c.ConnectionEstablished() {
-		return nil, fmt.Errorf("connection with vault is not yet established")
-	}
-
 	c.lock.Lock()
 	defer c.lock.Unlock()
 
@@ -314,21 +272,19 @@ func (c *vaultClient) GetConsulACL(token, path string) (*vaultapi.Secret, error)
 // the caller be notified of a renewal failure asynchronously for appropriate
 // actions to be taken. The caller of this function need not have to close the
 // error channel.
-func (c *vaultClient) RenewToken(token string, increment int) <-chan error {
-	// Create a buffered error channel
-	errCh := make(chan error, 1)
-
+func (c *vaultClient) RenewToken(token string, increment int) (<-chan error, error) {
 	if token == "" {
-		errCh <- fmt.Errorf("missing token")
-		close(errCh)
-		return errCh
+		err := fmt.Errorf("missing token")
+		return nil, err
 	}
 	if increment < 1 {
-		errCh <- fmt.Errorf("increment cannot be less than 1")
-		close(errCh)
-		return errCh
+		err := fmt.Errorf("increment cannot be less than 1")
+		return nil, err
 	}
 
+	// Create a buffered error channel
+	errCh := make(chan error, 1)
+
 	// Create a renewal request and indicate that the identifier in the
 	// request is a token and not a lease
 	renewalReq := &vaultClientRenewalRequest{
@@ -342,9 +298,10 @@ func (c *vaultClient) RenewToken(token string, increment int) <-chan error {
 	// error channel.
 	if err := c.renew(renewalReq); err != nil {
 		c.logger.Printf("[ERR] client.vault: renewal of token failed: %v", err)
+		return nil, err
 	}
 
-	return errCh
+	return errCh, nil
 }
 
 // RenewLease renews the supplied lease identifier for a supplied duration (in
@@ -354,23 +311,20 @@
 // This helps the caller be notified of a renewal failure asynchronously for
 // appropriate actions to be taken. The caller of this function need not have
 // to close the error channel.
-func (c *vaultClient) RenewLease(leaseId string, increment int) <-chan error {
-	c.logger.Printf("[DEBUG] client.vault: renewing lease %q", leaseId)
-	// Create a buffered error channel
-	errCh := make(chan error, 1)
-
+func (c *vaultClient) RenewLease(leaseId string, increment int) (<-chan error, error) {
 	if leaseId == "" {
-		errCh <- fmt.Errorf("missing lease ID")
-		close(errCh)
-		return errCh
+		err := fmt.Errorf("missing lease ID")
+		return nil, err
 	}
 
 	if increment < 1 {
-		errCh <- fmt.Errorf("increment cannot be less than 1")
-		close(errCh)
-		return errCh
+		err := fmt.Errorf("increment cannot be less than 1")
+		return nil, err
 	}
 
+	// Create a buffered error channel
+	errCh := make(chan error, 1)
+
 	// Create a renewal request using the supplied lease and duration
 	renewalReq := &vaultClientRenewalRequest{
 		errCh: errCh,
@@ -381,9 +335,10 @@ func (c *vaultClient) RenewLease(leaseId string, increment int) <-chan error {
 	// Renew the secret and send any error to the dedicated error channel
 	if err := c.renew(renewalReq); err != nil {
 		c.logger.Printf("[ERR] client.vault: renewal of lease failed: %v", err)
+		return nil, err
 	}
 
-	return errCh
+	return errCh, nil
 }
 
 // renew is a common method to handle renewal of both tokens and secret leases.
@@ -395,6 +350,9 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error {
 	c.lock.Lock()
 	defer c.lock.Unlock()
 
+	if !c.config.Enabled {
+		return fmt.Errorf("vault client not enabled")
+	}
 	if !c.running {
 		return fmt.Errorf("vault client is not running")
 	}
@@ -402,10 +360,15 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error {
 	if req == nil {
 		return fmt.Errorf("nil renewal request")
 	}
+	if req.errCh == nil {
+		return fmt.Errorf("renewal request error channel nil")
+	}
 	if req.id == "" {
+		close(req.errCh)
 		return fmt.Errorf("missing id in renewal request")
 	}
 	if req.increment < 1 {
+		close(req.errCh)
 		return fmt.Errorf("increment cannot be less than 1")
 	}
 
@@ -423,8 +386,7 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error {
 		renewResp, err := c.client.Auth().Token().RenewSelf(req.increment)
 		if err != nil {
 			renewalErr = fmt.Errorf("failed to renew the vault token: %v", err)
-		}
-		if renewResp == nil || renewResp.Auth == nil {
+		} else if renewResp == nil || renewResp.Auth == nil {
 			renewalErr = fmt.Errorf("failed to renew the vault token")
 		} else {
 			// Don't set this if renewal fails
@@ -435,8 +397,7 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error {
 		renewResp, err := c.client.Sys().Renew(req.id, req.increment)
 		if err != nil {
 			renewalErr = fmt.Errorf("failed to renew vault secret: %v", err)
-		}
-		if renewResp == nil {
+		} else if renewResp == nil {
 			renewalErr = fmt.Errorf("failed to renew vault secret")
 		} else {
 			// Don't set this if renewal fails
@@ -463,11 +424,12 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error {
 	fatal := false
 	if renewalErr != nil &&
 		(strings.Contains(renewalErr.Error(), "lease not found or lease is not renewable") ||
-			strings.Contains(renewalErr.Error(), "token not found")) {
+			strings.Contains(renewalErr.Error(), "token not found") ||
+			strings.Contains(renewalErr.Error(), "permission denied")) {
 		fatal = true
 	} else if renewalErr != nil {
 		c.logger.Printf("[DEBUG] client.vault: req.increment: %d, leaseDuration: %d, duration: %d", req.increment, leaseDuration, duration)
-		c.logger.Printf("[ERR] client.vault: renewal of lease or token failed due to a non-fatal error. Retrying at %v", next.String())
+		c.logger.Printf("[ERR] client.vault: renewal of lease or token failed due to a non-fatal error. Retrying at %v: %v", next.String(), renewalErr)
 	}
 
 	if c.isTracked(req.id) {
@@ -537,10 +499,6 @@ func (c *vaultClient) run() {
 		return
 	}
 
-	c.lock.Lock()
-	c.running = true
-	c.lock.Unlock()
-
 	var renewalCh <-chan time.Time
 	for c.config.Enabled && c.running {
 		// Fetches the candidate for next renewal
diff --git a/client/vaultclient/vaultclient_test.go b/client/vaultclient/vaultclient_test.go
index 3ff1b128ba4..e55fd35bf3e 100644
--- a/client/vaultclient/vaultclient_test.go
+++ b/client/vaultclient/vaultclient_test.go
@@ -11,38 +11,6 @@ import (
 	vaultapi "github.com/hashicorp/vault/api"
 )
 
-func TestVaultClient_EstablishConnection(t *testing.T) {
-	v := testutil.NewTestVault(t)
-
-	logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags)
-	v.Config.ConnectionRetryIntv = 100 * time.Millisecond
-	v.Config.TaskTokenTTL = "10s"
-	c, err := NewVaultClient(v.Config, logger, nil)
-	if err != nil {
-		t.Fatalf("failed to build vault client: %v", err)
-	}
-
-	c.Start()
-	defer c.Stop()
-
-	// Sleep a little while and check that no connection has been established.
-	time.Sleep(100 * time.Duration(testutil.TestMultiplier()) * time.Millisecond)
-
-	if c.ConnectionEstablished() {
-		t.Fatalf("ConnectionEstablished() returned true before Vault server started")
-	}
-
-	// Start Vault
-	v.Start()
-	defer v.Stop()
-
-	testutil.WaitForResult(func() (bool, error) {
-		return c.ConnectionEstablished(), nil
-	}, func(err error) {
-		t.Fatalf("Connection not established")
-	})
-}
-
 func TestVaultClient_TokenRenewals(t *testing.T) {
 	v := testutil.NewTestVault(t).Start()
 	defer v.Stop()
@@ -89,12 +57,15 @@ func TestVaultClient_TokenRenewals(t *testing.T) {
 
 		tokens[i] = secret.Auth.ClientToken
 
-		errCh := c.RenewToken(tokens[i], secret.Auth.LeaseDuration)
+		errCh, err := c.RenewToken(tokens[i], secret.Auth.LeaseDuration)
+		if err != nil {
+			t.Fatalf("Unexpected error: %v", err)
+		}
+
 		go func(errCh <-chan error) {
-			var err error
 			for {
 				select {
-				case err = <-errCh:
+				case err := <-errCh:
 					t.Fatalf("error while renewing the token: %v", err)
 				}
 			}
@@ -105,7 +76,7 @@ func TestVaultClient_TokenRenewals(t *testing.T) {
 		t.Fatalf("bad: heap length: expected: %d, actual: %d", num, c.heap.Length())
 	}
 
-	time.Sleep(5 * time.Second)
+	time.Sleep(time.Duration(5*testutil.TestMultiplier()) * time.Second)
 
 	for i := 0; i < num; i++ {
 		if err := c.StopRenewToken(tokens[i]); err != nil {
diff --git a/client/vaultclient/vaultclient_testing.go b/client/vaultclient/vaultclient_testing.go
new file mode 100644
index 00000000000..7f9310068c3
--- /dev/null
+++ b/client/vaultclient/vaultclient_testing.go
@@ -0,0 +1,90 @@
+package vaultclient
+
+import (
+	"github.com/hashicorp/nomad/nomad/structs"
+	vaultapi "github.com/hashicorp/vault/api"
+)
+
+// MockVaultClient is used for testing the vaultclient integration
+type MockVaultClient struct {
+	// StoppedTokens tracks the tokens that have stopped renewing
+	StoppedTokens []string
+
+	// RenewTokens are the tokens that have been renewed and their error
+	// channels
+	RenewTokens map[string]chan error
+
+	// RenewTokenErrors is used to return an error when the RenewToken is called
+	// with the given token
+	RenewTokenErrors map[string]error
+
+	// DeriveTokenErrors maps an allocation ID and tasks to an error when the
+	// token is derived
+	DeriveTokenErrors map[string]map[string]error
+}
+
+// NewMockVaultClient returns a MockVaultClient for testing
+func NewMockVaultClient() *MockVaultClient { return &MockVaultClient{} }
+
+func (vc *MockVaultClient) DeriveToken(a *structs.Allocation, tasks []string) (map[string]string, error) {
+	tokens := make(map[string]string, len(tasks))
+	for _, task := range tasks {
+		if tasks, ok := vc.DeriveTokenErrors[a.ID]; ok {
+			if err, ok := tasks[task]; ok {
+				return nil, err
+			}
+		}
+
+		tokens[task] = structs.GenerateUUID()
+	}
+
+	return tokens, nil
+}
+
+func (vc *MockVaultClient) SetDeriveTokenError(allocID string, tasks []string, err error) {
+	if vc.DeriveTokenErrors == nil {
+		vc.DeriveTokenErrors = make(map[string]map[string]error, 10)
+	}
+
+	if _, ok := vc.DeriveTokenErrors[allocID]; !ok {
+		vc.DeriveTokenErrors[allocID] = make(map[string]error, 10)
+	}
+
+	for _, task := range tasks {
+		vc.DeriveTokenErrors[allocID][task] = err
+	}
+}
+
+func (vc *MockVaultClient) RenewToken(token string, interval int) (<-chan error, error) {
+	if err, ok := vc.RenewTokenErrors[token]; ok {
+		return nil, err
+	}
+
+	renewCh := make(chan error)
+	if vc.RenewTokens == nil {
+		vc.RenewTokens = make(map[string]chan error, 10)
+	}
+	vc.RenewTokens[token] = renewCh
+	return renewCh, nil
+}
+
+func (vc *MockVaultClient) SetRenewTokenError(token string, err error) {
+	if vc.RenewTokenErrors == nil {
+		vc.RenewTokenErrors = make(map[string]error, 10)
+	}
+
+	vc.RenewTokenErrors[token] = err
+}
+
+func (vc *MockVaultClient) StopRenewToken(token string) error {
+	vc.StoppedTokens = append(vc.StoppedTokens, token)
+	return nil
+}
+
+func (vc *MockVaultClient) RenewLease(leaseId string, interval int) (<-chan error, error) {
+	return nil, nil
+}
+func (vc *MockVaultClient) StopRenewLease(leaseId string) error { return nil }
+func (vc *MockVaultClient) Start()                              {}
+func (vc *MockVaultClient) Stop()                               {}
+func (vc *MockVaultClient) GetConsulACL(string, string) (*vaultapi.Secret, error) { return nil, nil }
diff --git a/command/alloc_status.go b/command/alloc_status.go
index 71f859c0765..1f5c6bb61da 100644
--- a/command/alloc_status.go
+++ b/command/alloc_status.go
@@ -212,6 +212,7 @@ func (c *AllocStatusCommand) Run(args []string) int {
 		fmt.Sprintf("Node ID|%s", limit(alloc.NodeID, length)),
 		fmt.Sprintf("Job ID|%s", alloc.JobID),
 		fmt.Sprintf("Client Status|%s", alloc.ClientStatus),
+		fmt.Sprintf("Client Description|%s", alloc.ClientDescription),
 		fmt.Sprintf("Created At|%s", formatUnixNanoTime(alloc.CreateTime)),
 	}
 
@@ -334,6 +335,24 @@ func (c *AllocStatusCommand) outputTaskStatus(state *api.TaskState) {
 			} else {
 				desc = "Task exceeded restart policy"
 			}
+		case api.TaskDiskExceeded:
+			if event.DiskLimit != 0 && event.DiskSize != 0 {
+				desc = fmt.Sprintf("Disk size exceeded maximum: %d > %d", event.DiskSize, event.DiskLimit)
+			} else {
+				desc = "Task exceeded disk quota"
+			}
+		case api.TaskVaultRenewalFailed:
+			if event.VaultError != "" {
+				desc = event.VaultError
+			} else {
+				desc = "Task's Vault token failed to be renewed"
+			}
+		case api.TaskSiblingFailed:
+			if event.FailedSibling != "" {
+				desc = fmt.Sprintf("Task's sibling %q failed", event.FailedSibling)
+			} else {
+				desc = "Task's sibling failed"
+			}
 		}
 
 		// Reverse order so we are sorted by time
diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go
index b06a68307eb..dad219bca74 100644
--- a/nomad/structs/structs.go
+++ b/nomad/structs/structs.go
@@ -2209,7 +2209,8 @@ func (ts *TaskState) Failed() bool {
 	}
 
 	switch ts.Events[l-1].Type {
-	case TaskDiskExceeded, TaskNotRestarting, TaskArtifactDownloadFailed, TaskFailedValidation:
+	case TaskDiskExceeded, TaskNotRestarting, TaskArtifactDownloadFailed,
+		TaskFailedValidation, TaskVaultRenewalFailed:
 		return true
 	default:
 		return false
@@ -2279,6 +2280,9 @@ const (
 	// TaskSiblingFailed indicates that a sibling task in the task group has
 	// failed.
 	TaskSiblingFailed = "Sibling task failed"
+
+	// TaskVaultRenewalFailed indicates that Vault token renewal failed
+	TaskVaultRenewalFailed = "Vault token renewal failed"
 )
 
 // TaskEvent is an event that effects the state of a task and contains meta-data
@@ -2322,6 +2326,9 @@ type TaskEvent struct {
 	// Name of the sibling task that caused termination of the task that
 	// the TaskEvent refers to.
 	FailedSibling string
+
+	// VaultError is the error from token renewal
+	VaultError string
 }
 
 func (te *TaskEvent) GoString() string {
@@ -2419,6 +2426,13 @@ func (e *TaskEvent) SetFailedSibling(sibling string) *TaskEvent {
 	return e
 }
 
+func (e *TaskEvent) SetVaultRenewalError(err error) *TaskEvent {
+	if err != nil {
+		e.VaultError = err.Error()
+	}
+	return e
+}
+
 // TaskArtifact is an artifact to download before running the task.
 type TaskArtifact struct {
 	// GetterSource is the source to download an artifact using go-getter
diff --git a/nomad/vault.go b/nomad/vault.go
index 830d1e85a51..119335553ca 100644
--- a/nomad/vault.go
+++ b/nomad/vault.go
@@ -320,8 +320,6 @@ OUTER:
 		}
 	}
 
-	atomic.StoreInt32(&v.connEstablished, 1)
-
 	// Retrieve our token, validate it and parse the lease duration
 	if err := v.parseSelfToken(); err != nil {
 		v.logger.Printf("[ERR] vault: failed to lookup self token and not retrying: %v", err)
@@ -340,6 +338,8 @@ OUTER:
 			time.Duration(v.tokenData.CreationTTL)*time.Second)
 		v.tomb.Go(wrapNilError(v.renewalLoop))
 	}
+
+	atomic.StoreInt32(&v.connEstablished, 1)
 }
 
 // renewalLoop runs the renew loop. This should only be called if we are given a
@@ -562,7 +562,7 @@ func (v *vaultClient) CreateToken(ctx context.Context, a *structs.Allocation, ta
 			"NodeID": a.NodeID,
 		},
 		TTL:         v.childTTL,
-		DisplayName: fmt.Sprintf("%s-%s", a.ID, task),
+		DisplayName: fmt.Sprintf("%s-%s", a.ID, task),
 	}

 	// Ensure we are under our rate limit
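
Note on the RenewToken/RenewLease signature change above (editor's sketch, not part of the patch): both methods now return (<-chan error, error), so synchronous validation failures surface immediately while renewal failures still arrive later on the channel. The following is a minimal consumer, assuming only the behavior visible in this diff; the renewer interface is a hypothetical stand-in for the subset of vaultclient.VaultClient used here.

package main

import (
	"fmt"
	"log"
)

// renewer is a hypothetical stand-in for the subset of the
// vaultclient.VaultClient interface exercised below.
type renewer interface {
	RenewToken(token string, increment int) (<-chan error, error)
	StopRenewToken(token string) error
}

// watchToken mirrors how AllocRunner.deriveVaultTokens consumes the new
// two-value return: fail fast on the error, then watch the channel.
func watchToken(c renewer, token string) error {
	renewCh, err := c.RenewToken(token, 10)
	if err != nil {
		// Synchronous failure: renewal never started, nothing to clean up.
		return fmt.Errorf("failed to start token renewal: %v", err)
	}

	go func() {
		// Asynchronous failure: the token could not be renewed after
		// renewal had started (for example it expired or was revoked).
		if renewErr := <-renewCh; renewErr != nil {
			log.Printf("[WARN] token renewal failed: %v", renewErr)
			// The task runner reacts to this by destroying the task
			// with a TaskVaultRenewalFailed event.
		}
	}()
	return nil
}

func main() {}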
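Similarly, a sketch of the failure-injection pattern the new alloc-runner tests rely on, using only MockVaultClient methods shown in this patch; the test name and the task name "web" are illustrative assumptions, not code from the patch.

package client

import (
	"fmt"
	"testing"

	"github.com/hashicorp/nomad/client/vaultclient"
	"github.com/hashicorp/nomad/nomad/mock"
)

// TestMockRenewalInjection derives a token from the mock and then forces
// subsequent renewals to fail, as the VaultTokens_Invalid restore test does.
func TestMockRenewalInjection(t *testing.T) {
	vc := vaultclient.NewMockVaultClient()

	// The mock hands back a fresh UUID per requested task name.
	tokens, err := vc.DeriveToken(mock.Alloc(), []string{"web"})
	if err != nil {
		t.Fatalf("DeriveToken failed: %v", err)
	}
	token := tokens["web"]

	// Renewal succeeds until an error is injected for this token.
	if _, err := vc.RenewToken(token, 10); err != nil {
		t.Fatalf("unexpected renewal error: %v", err)
	}

	vc.SetRenewTokenError(token, fmt.Errorf("test renewal error"))
	if _, err := vc.RenewToken(token, 10); err == nil {
		t.Fatalf("expected injected renewal error")
	}
}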