Skip to content

Commit

Permalink
Merge pull request #1713 from hashicorp/f-alloc-runner-vault
Browse files Browse the repository at this point in the history
Vault integration in client
  • Loading branch information
dadgar authored Sep 20, 2016
2 parents e5fd8e6 + 5c3acf1 commit d49dda4
Show file tree
Hide file tree
Showing 18 changed files with 933 additions and 232 deletions.
6 changes: 6 additions & 0 deletions api/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@ const (
TaskDownloadingArtifacts = "Downloading Artifacts"
TaskArtifactDownloadFailed = "Failed Artifact Download"
TaskDiskExceeded = "Disk Exceeded"
TaskVaultRenewalFailed = "Vault token renewal failed"
TaskSiblingFailed = "Sibling task failed"
)

// TaskEvent is an event that affects the state of a task and contains meta-data
Expand All @@ -250,4 +252,8 @@ type TaskEvent struct {
StartDelay int64
DownloadError string
ValidationError string
DiskLimit int64
DiskSize int64
FailedSibling string
VaultError string
}
263 changes: 234 additions & 29 deletions client/alloc_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package client

import (
"fmt"
"io/ioutil"
"log"
"os"
"path/filepath"
Expand All @@ -12,6 +13,7 @@ import (
"github.com/hashicorp/nomad/client/allocdir"
"github.com/hashicorp/nomad/client/config"
"github.com/hashicorp/nomad/client/driver"
"github.com/hashicorp/nomad/client/vaultclient"
"github.com/hashicorp/nomad/nomad/structs"

cstructs "github.com/hashicorp/nomad/client/structs"
Expand All @@ -29,6 +31,10 @@ const (
// watchdogInterval is the interval at which resource constraints for the
// allocation are being checked and enforced.
watchdogInterval = 5 * time.Second

// vaultTokenFile is the name of the file holding the Vault token inside the
// task's secret directory
vaultTokenFile = "vault_token"
)

// AllocStateUpdater is used to update the status of an allocation
Expand Down Expand Up @@ -62,6 +68,9 @@ type AllocRunner struct {

updateCh chan *structs.Allocation

vaultClient vaultclient.VaultClient
vaultTokens map[string]vaultToken

destroy bool
destroyCh chan struct{}
destroyLock sync.Mutex
Expand All @@ -82,19 +91,20 @@ type allocRunnerState struct {

// NewAllocRunner is used to create a new allocation context
func NewAllocRunner(logger *log.Logger, config *config.Config, updater AllocStateUpdater,
alloc *structs.Allocation) *AllocRunner {
alloc *structs.Allocation, vaultClient vaultclient.VaultClient) *AllocRunner {
ar := &AllocRunner{
config: config,
updater: updater,
logger: logger,
alloc: alloc,
dirtyCh: make(chan struct{}, 1),
tasks: make(map[string]*TaskRunner),
taskStates: copyTaskStates(alloc.TaskStates),
restored: make(map[string]struct{}),
updateCh: make(chan *structs.Allocation, 64),
destroyCh: make(chan struct{}),
waitCh: make(chan struct{}),
config: config,
updater: updater,
logger: logger,
alloc: alloc,
dirtyCh: make(chan struct{}, 1),
tasks: make(map[string]*TaskRunner),
taskStates: copyTaskStates(alloc.TaskStates),
restored: make(map[string]struct{}),
updateCh: make(chan *structs.Allocation, 64),
destroyCh: make(chan struct{}),
waitCh: make(chan struct{}),
vaultClient: vaultClient,
}
return ar
}
Expand Down Expand Up @@ -133,6 +143,9 @@ func (r *AllocRunner) RestoreState() error {
return e
}

// Recover the Vault tokens
vaultErr := r.recoverVaultTokens()

// Restore the task runners
var mErr multierror.Error
for name, state := range r.taskStates {
Expand All @@ -144,6 +157,10 @@ func (r *AllocRunner) RestoreState() error {
task)
r.tasks[name] = tr

if vt, ok := r.vaultTokens[name]; ok {
tr.SetVaultToken(vt.token, vt.renewalCh)
}

// Skip tasks in terminal states.
if state.State == structs.TaskStateDead {
continue
Expand All @@ -157,6 +174,21 @@ func (r *AllocRunner) RestoreState() error {
go tr.Run()
}
}

// Since this is somewhat of an expected case we do not return an error but
// handle it gracefully.
if vaultErr != nil {
msg := fmt.Sprintf("failed to recover Vault tokens for allocation %q: %v", r.alloc.ID, vaultErr)
r.logger.Printf("[ERR] client: %s", msg)
r.setStatus(structs.AllocClientStatusFailed, msg)

// Destroy the task runners and set the error
r.destroyTaskRunners(structs.NewTaskEvent(structs.TaskVaultRenewalFailed).SetVaultRenewalError(vaultErr))

// Handle cleanup
go r.handleDestroy()
}

return mErr.ErrorOrNil()
}

Expand Down Expand Up @@ -333,17 +365,26 @@ func (r *AllocRunner) setTaskState(taskName, state string, event *structs.TaskEv
taskState.State = state
r.appendTaskEvent(taskState, event)

// If the task failed, we should kill all the other tasks in the task group.
if state == structs.TaskStateDead && taskState.Failed() {
var destroyingTasks []string
for task, tr := range r.tasks {
if task != taskName {
destroyingTasks = append(destroyingTasks, task)
tr.Destroy(structs.NewTaskEvent(structs.TaskSiblingFailed).SetFailedSibling(taskName))
if state == structs.TaskStateDead {
// If the task has a Vault token, stop renewing it
if vt, ok := r.vaultTokens[taskName]; ok {
if err := r.vaultClient.StopRenewToken(vt.token); err != nil {
r.logger.Printf("[ERR] client: stopping token renewal for task %q failed: %v", taskName, err)
}
}
if len(destroyingTasks) > 0 {
r.logger.Printf("[DEBUG] client: task %q failed, destroying other tasks in task group: %v", taskName, destroyingTasks)

// If the task failed, we should kill all the other tasks in the task group.
if taskState.Failed() {
var destroyingTasks []string
for task, tr := range r.tasks {
if task != taskName {
destroyingTasks = append(destroyingTasks, task)
tr.Destroy(structs.NewTaskEvent(structs.TaskSiblingFailed).SetFailedSibling(taskName))
}
}
if len(destroyingTasks) > 0 {
r.logger.Printf("[DEBUG] client: task %q failed, destroying other tasks in task group: %v", taskName, destroyingTasks)
}
}
}

Expand Down Expand Up @@ -408,6 +449,15 @@ func (r *AllocRunner) Run() {
return
}

// Request Vault tokens for the tasks that require them
err := r.deriveVaultTokens()
if err != nil {
msg := fmt.Sprintf("failed to derive Vault token for allocation %q: %v", r.alloc.ID, err)
r.logger.Printf("[ERR] client: %s", msg)
r.setStatus(structs.AllocClientStatusFailed, msg)
return
}

// Start the task runners
r.logger.Printf("[DEBUG] client: starting task runners for alloc '%s'", r.alloc.ID)
r.taskLock.Lock()
Expand All @@ -416,10 +466,15 @@ func (r *AllocRunner) Run() {
continue
}

tr := NewTaskRunner(r.logger, r.config, r.setTaskState, r.ctx, r.Alloc(),
task.Copy())
tr := NewTaskRunner(r.logger, r.config, r.setTaskState, r.ctx, r.Alloc(), task.Copy())
r.tasks[task.Name] = tr
tr.MarkReceived()

// If the task has a vault token set it before running
if vt, ok := r.vaultTokens[task.Name]; ok {
tr.SetVaultToken(vt.token, vt.renewalCh)
}

go tr.Run()
}
r.taskLock.Unlock()
Expand Down Expand Up @@ -467,10 +522,24 @@ OUTER:
}
}

// Kill the task runners
r.destroyTaskRunners(taskDestroyEvent)

// Stop watching the shared allocation directory
r.ctx.AllocDir.StopDiskWatcher()

// Block until we should destroy the state of the alloc
r.handleDestroy()
r.logger.Printf("[DEBUG] client: terminating runner for alloc '%s'", r.alloc.ID)
}

// destroyTaskRunners destroys the task runners, waits for them to terminate and
// then saves state.
func (r *AllocRunner) destroyTaskRunners(destroyEvent *structs.TaskEvent) {
// Destroy each sub-task
runners := r.getTaskRunners()
for _, tr := range runners {
tr.Destroy(taskDestroyEvent)
tr.Destroy(destroyEvent)
}

// Wait for termination of the task runners
Expand All @@ -480,13 +549,149 @@ OUTER:

// Final state sync
r.syncStatus()
}

// Stop watching the shared allocation directory
r.ctx.AllocDir.StopDiskWatcher()
// vaultToken pairs a task's Vault token with the channel obtained when renewal
// of that token was started, so both can be handed to the task runner together.
type vaultToken struct {
	// token is the Vault token string; it is the value written to the task's
	// secret directory and passed to RenewToken / StopRenewToken.
	token string

	// renewalCh is the receive-only channel returned by RenewToken.
	// NOTE(review): presumably delivers renewal failures — confirm against
	// the vaultclient package.
	renewalCh <-chan error
}

// Block until we should destroy the state of the alloc
r.handleDestroy()
r.logger.Printf("[DEBUG] client: terminating runner for alloc '%s'", r.alloc.ID)
// deriveVaultTokens derives a Vault token for every task in the allocation
// that requires one and does not already have an entry in r.vaultTokens (for
// example because it was recovered from disk). Each new token is written to
// the task's secret directory, renewal is started for it, and the token and
// renewal channel are recorded in r.vaultTokens.
//
// This must be called after the allocation directory is created as the vault
// tokens are written to disk.
//
// It returns nil when no task needs a new token. On failure it returns an
// error; if starting renewal fails, renewal of every token currently held in
// r.vaultTokens is stopped and any errors from doing so are included.
func (r *AllocRunner) deriveVaultTokens() error {
	required, err := r.tasksRequiringVaultTokens()
	if err != nil {
		return err
	}

	// Only ask for tokens we do not already hold. Requesting tokens for tasks
	// whose token has been recovered would create tokens that are immediately
	// discarded, and would make an otherwise healthy allocation depend on an
	// unnecessary remote call. Reading from a nil map is safe here.
	pending := make([]string, 0, len(required))
	for _, task := range required {
		if _, ok := r.vaultTokens[task]; !ok {
			pending = append(pending, task)
		}
	}

	if len(pending) == 0 {
		return nil
	}

	if r.vaultTokens == nil {
		r.vaultTokens = make(map[string]vaultToken, len(pending))
	}

	// Get the tokens
	tokens, err := r.vaultClient.DeriveToken(r.Alloc(), pending)
	if err != nil {
		return fmt.Errorf("failed to derive Vault tokens: %v", err)
	}

	// Persist the tokens to the appropriate secret directories
	adir := r.ctx.AllocDir
	for task, token := range tokens {
		// Defensive: never overwrite a token that is already being renewed,
		// even if the response contains more tasks than were requested.
		if _, ok := r.vaultTokens[task]; ok {
			continue
		}

		secretDir, err := adir.GetSecretDir(task)
		if err != nil {
			return fmt.Errorf("failed to determine task %s secret dir in alloc %q: %v", task, r.alloc.ID, err)
		}

		// Write the token to the file system.
		// NOTE(review): mode 0777 makes the token file world-accessible and
		// executable; it is kept as-is because the task may run as a different
		// user than the client — confirm whether a tighter mode is viable.
		tokenPath := filepath.Join(secretDir, vaultTokenFile)
		if err := ioutil.WriteFile(tokenPath, []byte(token), 0777); err != nil {
			return fmt.Errorf("failed to save Vault tokens to secret dir for task %q in alloc %q: %v", task, r.alloc.ID, err)
		}

		// Start renewing the token
		renewCh, err := r.vaultClient.RenewToken(token, 10)
		if err != nil {
			var mErr multierror.Error
			errMsg := fmt.Errorf("failed to renew Vault token for task %q in alloc %q: %v", task, r.alloc.ID, err)
			multierror.Append(&mErr, errMsg)

			// Clean up any token that we have started renewing
			for _, vt := range r.vaultTokens {
				if stopErr := r.vaultClient.StopRenewToken(vt.token); stopErr != nil {
					multierror.Append(&mErr, stopErr)
				}
			}

			return mErr.ErrorOrNil()
		}
		r.vaultTokens[task] = vaultToken{token: token, renewalCh: renewCh}
	}

	return nil
}

// tasksRequiringVaultTokens lists the names of the tasks in this allocation's
// task group that declare a Vault block with at least one policy. It returns
// an error when the task group cannot be found in the allocation's job.
func (r *AllocRunner) tasksRequiringVaultTokens() ([]string, error) {
	group := r.alloc.Job.LookupTaskGroup(r.alloc.TaskGroup)
	if group == nil {
		return nil, fmt.Errorf("Failed to lookup task group in alloc")
	}

	// Collect only the tasks that actually ask for Vault policies.
	var names []string
	for _, t := range group.Tasks {
		if t.Vault == nil || len(t.Vault.Policies) == 0 {
			continue
		}
		names = append(names, t.Name)
	}

	return names, nil
}

// recoverVaultTokens reloads, from each task's secret directory, the Vault
// token of every task that requires one and starts renewing it. Only when all
// tokens were read and renewal was started for each of them is r.vaultTokens
// replaced with the recovered set; otherwise an error is returned and
// r.vaultTokens is left untouched.
func (r *AllocRunner) recoverVaultTokens() error {
	tasks, err := r.tasksRequiringVaultTokens()
	if err != nil {
		return err
	}
	if len(tasks) == 0 {
		return nil
	}

	allocDir := r.ctx.AllocDir
	recovered := make(map[string]vaultToken, len(tasks))

	// abort builds the error for a failed renewal and stops the renewals that
	// were already started during this call, collecting any stop failures.
	abort := func(task string, cause error) error {
		var result multierror.Error
		renewErr := fmt.Errorf("failed to renew Vault token for task %q in alloc %q: %v", task, r.alloc.ID, cause)
		multierror.Append(&result, renewErr)

		for _, vt := range recovered {
			if stopErr := r.vaultClient.StopRenewToken(vt.token); stopErr != nil {
				multierror.Append(&result, stopErr)
			}
		}

		return result.ErrorOrNil()
	}

	for _, name := range tasks {
		dir, err := allocDir.GetSecretDir(name)
		if err != nil {
			return fmt.Errorf("failed to determine task %s secret dir in alloc %q: %v", name, r.alloc.ID, err)
		}

		// The token lives in a well-known file inside the secret directory.
		raw, err := ioutil.ReadFile(filepath.Join(dir, vaultTokenFile))
		if err != nil {
			return fmt.Errorf("failed to read token for task %q in alloc %q: %v", name, r.alloc.ID, err)
		}

		tok := string(raw)
		ch, err := r.vaultClient.RenewToken(tok, 10)
		if err != nil {
			return abort(name, err)
		}

		recovered[name] = vaultToken{token: tok, renewalCh: ch}
	}

	r.vaultTokens = recovered
	return nil
}

// checkResources monitors and enforces alloc resource usage. It returns an
Expand Down
Loading

0 comments on commit d49dda4

Please sign in to comment.