diff --git a/client/allocrunner/alloc_runner.go b/client/allocrunner/alloc_runner.go index 2179dcf2d0a..65b57876c39 100644 --- a/client/allocrunner/alloc_runner.go +++ b/client/allocrunner/alloc_runner.go @@ -157,7 +157,7 @@ func (ar *allocRunner) initTaskRunners(tasks []*structs.Task) error { StateDB: ar.stateDB, StateUpdater: ar, Consul: ar.consulClient, - VaultClient: ar.vaultClient, + Vault: ar.vaultClient, PluginSingletonLoader: ar.pluginSingletonLoader, } @@ -181,17 +181,61 @@ func (ar *allocRunner) Run() { ar.destroyedLock.Lock() defer ar.destroyedLock.Unlock() + // Run should not be called after Destroy is called. This is a + // programming error. if ar.destroyed { - // Run should not be called after Destroy is called. This is a - // programming error. ar.logger.Error("alloc destroyed; cannot run") return } - ar.runLaunched = true + // If an alloc should not be run, ensure any restored task handles are + // destroyed and exit to wait for the AR to be GC'd by the client. + if !ar.shouldRun() { + ar.logger.Debug("not running terminal alloc") + + // Cleanup and sync state + states := ar.killTasks() + + // Get the client allocation + calloc := ar.clientAlloc(states) + + // Update the server + ar.stateUpdater.AllocStateUpdated(calloc) + + // Broadcast client alloc to listeners + ar.allocBroadcaster.Send(calloc) + return + } + + // Run! (and mark as having been run to ensure Destroy cleans up properly) + ar.runLaunched = true go ar.runImpl() } +// shouldRun returns true if the alloc is in a state that the alloc runner +// should run it. +func (ar *allocRunner) shouldRun() bool { + // Do not run allocs that are terminal + if ar.Alloc().TerminalStatus() { + ar.logger.Trace("alloc terminal; not running", + "desired_status", ar.Alloc().DesiredStatus, + "client_status", ar.Alloc().ClientStatus, + ) + return false + } + + // It's possible that the alloc local state was marked terminal before + // the server copy of the alloc (checked above) was marked as terminal, + // so check the local state as well. + switch clientStatus := ar.AllocState().ClientStatus; clientStatus { + case structs.AllocClientStatusComplete, structs.AllocClientStatusFailed, structs.AllocClientStatusLost: + ar.logger.Trace("alloc terminal; updating server and not running", "status", clientStatus) + return false + } + + return true +} + func (ar *allocRunner) runImpl() { // Close the wait channel on return defer close(ar.waitCh) @@ -354,7 +398,7 @@ func (ar *allocRunner) handleTaskStateUpdates() { ar.logger.Debug("task failure, destroying all tasks", "failed_task", killTask) } - ar.killTasks() + states = ar.killTasks() } // Get the client allocation @@ -369,8 +413,12 @@ func (ar *allocRunner) handleTaskStateUpdates() { } // killTasks kills all task runners, leader (if there is one) first. Errors are -// logged except taskrunner.ErrTaskNotRunning which is ignored. -func (ar *allocRunner) killTasks() { +// logged except taskrunner.ErrTaskNotRunning which is ignored. Task states +// after Kill has been called are returned. 
+func (ar *allocRunner) killTasks() map[string]*structs.TaskState { + var mu sync.Mutex + states := make(map[string]*structs.TaskState, len(ar.tasks)) + // Kill leader first, synchronously for name, tr := range ar.tasks { if !tr.IsLeader() { @@ -381,6 +429,9 @@ func (ar *allocRunner) killTasks() { if err != nil && err != taskrunner.ErrTaskNotRunning { ar.logger.Warn("error stopping leader task", "error", err, "task_name", name) } + + state := tr.TaskState() + states[name] = state break } @@ -398,9 +449,16 @@ func (ar *allocRunner) killTasks() { if err != nil && err != taskrunner.ErrTaskNotRunning { ar.logger.Warn("error stopping task", "error", err, "task_name", name) } + + state := tr.TaskState() + mu.Lock() + states[name] = state + mu.Unlock() }(name, tr) } wg.Wait() + + return states } // clientAlloc takes in the task states and returns an Allocation populated @@ -510,6 +568,12 @@ func (ar *allocRunner) AllocState() *state.State { } } + // Generate alloc to get other state fields + alloc := ar.clientAlloc(state.TaskStates) + state.ClientStatus = alloc.ClientStatus + state.ClientDescription = alloc.ClientDescription + state.DeploymentStatus = alloc.DeploymentStatus + return state } @@ -563,8 +627,11 @@ func (ar *allocRunner) Destroy() { } defer ar.destroyedLock.Unlock() - // Stop any running tasks - ar.killTasks() + // Stop any running tasks and persist states in case the client is + // shutdown before Destroy finishes. + states := ar.killTasks() + calloc := ar.clientAlloc(states) + ar.stateUpdater.AllocStateUpdated(calloc) // Wait for tasks to exit and postrun hooks to finish (if they ran at all) if ar.runLaunched { diff --git a/client/allocrunner/alloc_runner_test.go b/client/allocrunner/alloc_runner_test.go index 48b248f4ac2..bb117ce94e4 100644 --- a/client/allocrunner/alloc_runner_test.go +++ b/client/allocrunner/alloc_runner_test.go @@ -11,7 +11,6 @@ import ( consulapi "github.com/hashicorp/nomad/client/consul" "github.com/hashicorp/nomad/client/state" "github.com/hashicorp/nomad/client/vaultclient" - "github.com/hashicorp/nomad/helper/testlog" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/plugins/shared/catalog" @@ -57,20 +56,19 @@ func (m *MockStateUpdater) Reset() { // testAllocRunnerConfig returns a new allocrunner.Config with mocks and noop // versions of dependencies along with a cleanup func. 
func testAllocRunnerConfig(t *testing.T, alloc *structs.Allocation) (*Config, func()) { - logger := testlog.HCLogger(t) pluginLoader := catalog.TestPluginLoader(t) clientConf, cleanup := config.TestClientConfig(t) conf := &Config{ // Copy the alloc in case the caller edits and reuses it Alloc: alloc.Copy(), - Logger: logger, + Logger: clientConf.Logger, ClientConfig: clientConf, StateDB: state.NoopDB{}, - Consul: consulapi.NewMockConsulServiceClient(t, logger), + Consul: consulapi.NewMockConsulServiceClient(t, clientConf.Logger), Vault: vaultclient.NewMockVaultClient(), StateUpdater: &MockStateUpdater{}, PrevAllocWatcher: allocwatcher.NoopPrevAlloc{}, - PluginSingletonLoader: singleton.NewSingletonLoader(logger, pluginLoader), + PluginSingletonLoader: singleton.NewSingletonLoader(clientConf.Logger, pluginLoader), } return conf, cleanup } diff --git a/client/allocrunner/interfaces/task_lifecycle.go b/client/allocrunner/interfaces/task_lifecycle.go index b22ef285bc1..808afd37a73 100644 --- a/client/allocrunner/interfaces/task_lifecycle.go +++ b/client/allocrunner/interfaces/task_lifecycle.go @@ -109,8 +109,8 @@ type TaskKillResponse struct{} type TaskKillHook interface { TaskHook - // Kill is called when a task is going to be killed. - Kill(context.Context, *TaskKillRequest, *TaskKillResponse) error + // Killing is called when a task is going to be Killed or Restarted. + Killing(context.Context, *TaskKillRequest, *TaskKillResponse) error } type TaskExitedRequest struct{} diff --git a/client/allocrunner/taskrunner/interfaces/lifecycle.go b/client/allocrunner/taskrunner/interfaces/lifecycle.go index 84bbda228e9..1890471bf56 100644 --- a/client/allocrunner/taskrunner/interfaces/lifecycle.go +++ b/client/allocrunner/taskrunner/interfaces/lifecycle.go @@ -7,7 +7,13 @@ import ( ) type TaskLifecycle interface { + // Restart a task in place. If failure=false then the restart does not + // count as an attempt in the restart policy. Restart(ctx context.Context, event *structs.TaskEvent, failure bool) error + + // Sends a signal to a task. Signal(event *structs.TaskEvent, signal string) error + + // Kill a task permanently. Kill(ctx context.Context, event *structs.TaskEvent) error } diff --git a/client/allocrunner/taskrunner/lifecycle.go b/client/allocrunner/taskrunner/lifecycle.go index 7a314648080..224a3af6225 100644 --- a/client/allocrunner/taskrunner/lifecycle.go +++ b/client/allocrunner/taskrunner/lifecycle.go @@ -12,6 +12,7 @@ import ( func (tr *TaskRunner) Restart(ctx context.Context, event *structs.TaskEvent, failure bool) error { // Grab the handle handle := tr.getDriverHandle() + // Check it is running if handle == nil { return ErrTaskNotRunning @@ -20,12 +21,14 @@ func (tr *TaskRunner) Restart(ctx context.Context, event *structs.TaskEvent, fai // Emit the event since it may take a long time to kill tr.EmitEvent(event) + // Run the hooks prior to restarting the task + tr.killing() + // Tell the restart tracker that a restart triggered the exit tr.restartTracker.SetRestartTriggered(failure) // Kill the task using an exponential backoff in-case of failures. - destroySuccess, err := tr.handleDestroy(handle) - if !destroySuccess { + if err := tr.killTask(handle); err != nil { // We couldn't successfully destroy the resource created. tr.logger.Error("failed to kill task. 
Resources may have been leaked", "error", err) } @@ -36,7 +39,10 @@ func (tr *TaskRunner) Restart(ctx context.Context, event *structs.TaskEvent, fai return err } - <-waitCh + select { + case <-waitCh: + case <-ctx.Done(): + } return nil } @@ -61,7 +67,7 @@ func (tr *TaskRunner) Signal(event *structs.TaskEvent, s string) error { func (tr *TaskRunner) Kill(ctx context.Context, event *structs.TaskEvent) error { // Cancel the task runner to break out of restart delay or the main run // loop. - tr.ctxCancel() + tr.killCtxCancel() // Grab the handle handle := tr.getDriverHandle() @@ -75,16 +81,17 @@ func (tr *TaskRunner) Kill(ctx context.Context, event *structs.TaskEvent) error tr.EmitEvent(event) // Run the hooks prior to killing the task - tr.kill() + tr.killing() - // Tell the restart tracker that the task has been killed + // Tell the restart tracker that the task has been killed so it doesn't + // attempt to restart it. tr.restartTracker.SetKilled() // Kill the task using an exponential backoff in-case of failures. - destroySuccess, destroyErr := tr.handleDestroy(handle) - if !destroySuccess { + killErr := tr.killTask(handle) + if killErr != nil { // We couldn't successfully destroy the resource created. - tr.logger.Error("failed to kill task. Resources may have been leaked", "error", destroyErr) + tr.logger.Error("failed to kill task. Resources may have been leaked", "error", killErr) } // Block until task has exited. @@ -100,13 +107,16 @@ func (tr *TaskRunner) Kill(ctx context.Context, event *structs.TaskEvent) error return err } - <-waitCh + select { + case <-waitCh: + case <-ctx.Done(): + } // Store that the task has been destroyed and any associated error. - tr.UpdateState(structs.TaskStateDead, structs.NewTaskEvent(structs.TaskKilled).SetKillError(destroyErr)) + tr.UpdateState(structs.TaskStateDead, structs.NewTaskEvent(structs.TaskKilled).SetKillError(killErr)) - if destroyErr != nil { - return destroyErr + if killErr != nil { + return killErr } else if err := ctx.Err(); err != nil { return err } diff --git a/client/allocrunner/taskrunner/service_hook.go b/client/allocrunner/taskrunner/service_hook.go index 801cc20737f..a754cbcef94 100644 --- a/client/allocrunner/taskrunner/service_hook.go +++ b/client/allocrunner/taskrunner/service_hook.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "sync" + "time" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/nomad/client/allocrunner/interfaces" @@ -34,6 +35,7 @@ type serviceHook struct { logger log.Logger // The following fields may be updated + delay time.Duration driverExec tinterfaces.ScriptExecutor driverNet *cstructs.DriverNetwork canary bool @@ -53,6 +55,7 @@ func newServiceHook(c serviceHookConfig) *serviceHook { taskName: c.task.Name, services: c.task.Services, restarter: c.restarter, + delay: c.task.ShutdownDelay, } if res := c.alloc.TaskResources[c.task.Name]; res != nil { @@ -111,6 +114,7 @@ func (h *serviceHook) Update(ctx context.Context, req *interfaces.TaskUpdateRequ } // Update service hook fields + h.delay = task.ShutdownDelay h.taskEnv = req.TaskEnv h.services = task.Services h.networks = networks @@ -122,10 +126,35 @@ func (h *serviceHook) Update(ctx context.Context, req *interfaces.TaskUpdateRequ return h.consul.UpdateTask(oldTaskServices, newTaskServices) } +func (h *serviceHook) Killing(ctx context.Context, req *interfaces.TaskKillRequest, resp *interfaces.TaskKillResponse) error { + h.mu.Lock() + defer h.mu.Unlock() + + // Deregister before killing task + h.deregister() + + // If there's no shutdown 
delay, exit early + if h.delay == 0 { + return nil + } + + h.logger.Debug("waiting before killing task", "shutdown_delay", h.delay) + select { + case <-ctx.Done(): + case <-time.After(h.delay): + } + return nil +} + func (h *serviceHook) Exited(context.Context, *interfaces.TaskExitedRequest, *interfaces.TaskExitedResponse) error { h.mu.Lock() defer h.mu.Unlock() + h.deregister() + return nil +} +// deregister services from Consul. +func (h *serviceHook) deregister() { taskServices := h.getTaskServices() h.consul.RemoveTask(taskServices) @@ -134,7 +163,6 @@ func (h *serviceHook) Exited(context.Context, *interfaces.TaskExitedRequest, *in taskServices.Canary = !taskServices.Canary h.consul.RemoveTask(taskServices) - return nil } func (h *serviceHook) getTaskServices() *agentconsul.TaskServices { diff --git a/client/allocrunner/taskrunner/shutdown_delay_hook.go b/client/allocrunner/taskrunner/shutdown_delay_hook.go deleted file mode 100644 index 15f76ebd169..00000000000 --- a/client/allocrunner/taskrunner/shutdown_delay_hook.go +++ /dev/null @@ -1,36 +0,0 @@ -package taskrunner - -import ( - "context" - "time" - - log "github.com/hashicorp/go-hclog" - "github.com/hashicorp/nomad/client/allocrunner/interfaces" -) - -// shutdownDelayHook delays shutting down a task between deregistering it from -// Consul and actually killing it. -type shutdownDelayHook struct { - delay time.Duration - logger log.Logger -} - -func newShutdownDelayHook(delay time.Duration, logger log.Logger) *shutdownDelayHook { - h := &shutdownDelayHook{ - delay: delay, - } - h.logger = logger.Named(h.Name()) - return h -} - -func (*shutdownDelayHook) Name() string { - return "shutdown-delay" -} - -func (h *shutdownDelayHook) Kill(ctx context.Context, req *interfaces.TaskKillRequest, resp *interfaces.TaskKillResponse) error { - select { - case <-ctx.Done(): - case <-time.After(h.delay): - } - return nil -} diff --git a/client/allocrunner/taskrunner/task_runner.go b/client/allocrunner/taskrunner/task_runner.go index 8f29d7c00a2..b90f45457b7 100644 --- a/client/allocrunner/taskrunner/task_runner.go +++ b/client/allocrunner/taskrunner/task_runner.go @@ -78,12 +78,18 @@ type TaskRunner struct { // stateDB is for persisting localState and taskState stateDB cstate.StateDB - // ctx is the task runner's context representing the tasks's lifecycle. - // Canceling the context will cause the task to be destroyed. + // killCtx is the task runner's context representing the task's lifecycle. + // The context is canceled when the task is killed. + killCtx context.Context + + // killCtxCancel is called when killing a task. + killCtxCancel context.CancelFunc + + // ctx is used to exit the TaskRunner *without* affecting task state. ctx context.Context - // ctxCancel is used to exit the task runner's Run loop without - // stopping the task. Shutdown hooks are run. + // ctxCancel causes the TaskRunner to exit immediately without + // affecting task state. Useful for testing or graceful agent shutdown. ctxCancel context.CancelFunc // Logger is the logger for the task runner. @@ -168,8 +174,8 @@ type Config struct { TaskDir *allocdir.TaskDir Logger log.Logger - // VaultClient is the client to use to derive and renew Vault tokens - VaultClient vaultclient.VaultClient + // Vault is the client to use to derive and renew Vault tokens + Vault vaultclient.VaultClient // StateDB is used to store and restore state. 
StateDB cstate.StateDB @@ -183,9 +189,12 @@ type Config struct { } func NewTaskRunner(config *Config) (*TaskRunner, error) { - // Create a context for the runner + // Create a context for causing the runner to exit trCtx, trCancel := context.WithCancel(context.Background()) + // Create a context for killing the runner + killCtx, killCancel := context.WithCancel(context.Background()) + // Initialize the environment builder envBuilder := env.NewBuilder( config.ClientConfig.Node, @@ -210,11 +219,13 @@ func NewTaskRunner(config *Config) (*TaskRunner, error) { taskLeader: config.Task.Leader, envBuilder: envBuilder, consulClient: config.Consul, - vaultClient: config.VaultClient, + vaultClient: config.Vault, state: tstate, localState: state.NewLocalState(), stateDB: config.StateDB, stateUpdater: config.StateUpdater, + killCtx: killCtx, + killCtxCancel: killCancel, ctx: trCtx, ctxCancel: trCancel, triggerUpdateCh: make(chan struct{}, triggerUpdateChCap), @@ -299,7 +310,16 @@ func (tr *TaskRunner) Run() { go tr.handleUpdates() MAIN: - for tr.ctx.Err() == nil { + for { + select { + case <-tr.killCtx.Done(): + break MAIN + case <-tr.ctx.Done(): + // TaskRunner was told to exit immediately + return + default: + } + // Run the prestart hooks if err := tr.prestart(); err != nil { tr.logger.Error("prestart failed", "error", err) @@ -307,8 +327,13 @@ MAIN: goto RESTART } - if tr.ctx.Err() != nil { + select { + case <-tr.killCtx.Done(): break MAIN + case <-tr.ctx.Done(): + // TaskRunner was told to exit immediately + return + default: } // Run the task @@ -327,12 +352,19 @@ MAIN: { handle := tr.getDriverHandle() - // Do *not* use tr.ctx here as it would cause Wait() to - // unblock before the task exits when Kill() is called. + // Do *not* use tr.killCtx here as it would cause + // Wait() to unblock before the task exits when Kill() + // is called. if resultCh, err := handle.WaitCh(context.Background()); err != nil { tr.logger.Error("wait task failed", "error", err) } else { - result = <-resultCh + select { + case result = <-resultCh: + // WaitCh returned a result + case <-tr.ctx.Done(): + // TaskRunner was told to exit immediately + return + } } } @@ -355,9 +387,12 @@ MAIN: // Actually restart by sleeping and also watching for destroy events select { case <-time.After(restartDelay): - case <-tr.ctx.Done(): + case <-tr.killCtx.Done(): tr.logger.Trace("task killed between restarts", "delay", restartDelay) break MAIN + case <-tr.ctx.Done(): + // TaskRunner was told to exit immediately + return } } @@ -444,7 +479,20 @@ func (tr *TaskRunner) runDriver() error { //TODO mounts and devices //XXX Evaluate and encode driver config - // Start the job + // If there's already a task handle (eg from a Restore) there's nothing + // to do except update state. + if tr.getDriverHandle() != nil { + // Ensure running state is persisted but do *not* append a new + // task event as restoring is a client event and not relevant + // to a task's lifecycle. + if err := tr.updateStateImpl(structs.TaskStateRunning); err != nil { + //TODO return error and destroy task to avoid an orphaned task? 
+ tr.logger.Warn("error persisting task state", "error", err) + } + return nil + } + + // Start the job if there's no existing handle (or if RecoverTask failed) handle, net, err := tr.driver.StartTask(taskConfig) if err != nil { return fmt.Errorf("driver start failed: %v", err) @@ -452,9 +500,18 @@ func (tr *TaskRunner) runDriver() error { tr.localStateLock.Lock() tr.localState.TaskHandle = handle + tr.localState.DriverNetwork = net + if err := tr.stateDB.PutTaskRunnerLocalState(tr.allocID, tr.taskName, tr.localState); err != nil { + //TODO Nomad will be unable to restore this task; try to kill + // it now and fail? In general we prefer to leave running + // tasks running even if the agent encounters an error. + tr.logger.Warn("error persisting local task state; may be unable to restore after a Nomad restart", + "error", err, "task_id", handle.Config.ID) + } tr.localStateLock.Unlock() tr.setDriverHandle(NewDriverHandle(tr.driver, taskConfig.ID, tr.Task(), net)) + // Emit an event that we started tr.UpdateState(structs.TaskStateRunning, structs.NewTaskEvent(structs.TaskStarted)) return nil @@ -525,17 +582,17 @@ func (tr *TaskRunner) initDriver() error { return nil } -// handleDestroy kills the task handle. In the case that killing fails, -// handleDestroy will retry with an exponential backoff and will give up at a -// given limit. It returns whether the task was destroyed and the error -// associated with the last kill attempt. -func (tr *TaskRunner) handleDestroy(handle *DriverHandle) (destroyed bool, err error) { +// killTask kills the task handle. In the case that killing fails, +// killTask will retry with an exponential backoff and will give up at a +// given limit. Returns an error if the task could not be killed. +func (tr *TaskRunner) killTask(handle *DriverHandle) error { // Cap the number of times we attempt to kill the task. + var err error for i := 0; i < killFailureLimit; i++ { if err = handle.Kill(); err != nil { if err == drivers.ErrTaskNotFound { tr.logger.Warn("couldn't find task to kill", "task_id", handle.ID()) - return true, nil + return nil } // Calculate the new backoff backoff := (1 << (2 * uint64(i))) * killBackoffBaseline @@ -547,10 +604,10 @@ func (tr *TaskRunner) handleDestroy(handle *DriverHandle) (destroyed bool, err e time.Sleep(backoff) } else { // Kill was successful - return true, nil + return nil } } - return + return err } // persistLocalState persists local state to disk synchronously. @@ -591,39 +648,84 @@ func (tr *TaskRunner) Restore() error { ls.Canonicalize() tr.localState = ls } + if ts != nil { ts.Canonicalize() tr.state = ts } + + // If a TaskHandle was persisted, ensure it is valid or destroy it. + if taskHandle := tr.localState.TaskHandle; taskHandle != nil { + //TODO if RecoverTask returned the DriverNetwork we wouldn't + // have to persist it at all! + tr.restoreHandle(taskHandle, tr.localState.DriverNetwork) + } return nil } +// restoreHandle ensures a TaskHandle is valid by calling Driver.RecoverTask +// and sets the driver handle. If the TaskHandle is not valid, DestroyTask is +// called. 
+func (tr *TaskRunner) restoreHandle(taskHandle *drivers.TaskHandle, net *cstructs.DriverNetwork) { + // Ensure handle is well-formed + if taskHandle.Config == nil { + return + } + + if err := tr.driver.RecoverTask(taskHandle); err != nil { + tr.logger.Error("error recovering task; destroying and restarting", + "error", err, "task_id", taskHandle.Config.ID) + + // Try to clean up any existing task state in the plugin before restarting + if err := tr.driver.DestroyTask(taskHandle.Config.ID, true); err != nil { + // Ignore ErrTaskNotFound errors as ideally + // this task has already been stopped and + // therefore doesn't exist. + if err != drivers.ErrTaskNotFound { + tr.logger.Warn("error destroying unrecoverable task", + "error", err, "task_id", taskHandle.Config.ID) + } + + } + + return + } + + // Update driver handle on task runner + tr.setDriverHandle(NewDriverHandle(tr.driver, taskHandle.Config.ID, tr.Task(), net)) + return +} + // UpdateState sets the task runners allocation state and triggers a server // update. func (tr *TaskRunner) UpdateState(state string, event *structs.TaskEvent) { + tr.stateLock.Lock() + defer tr.stateLock.Unlock() + tr.logger.Trace("setting task state", "state", state, "event", event.Type) - // Update the local state - tr.setStateLocal(state, event) + // Append the event + tr.appendEvent(event) + + // Update the state + if err := tr.updateStateImpl(state); err != nil { + // Only log the error as persistence errors should not + // affect task state. + tr.logger.Error("error persisting task state", "error", err, "event", event, "state", state) + } // Notify the alloc runner of the transition tr.stateUpdater.TaskStateUpdated() } -// setStateLocal updates the local in-memory state, persists a copy to disk and returns a -// copy of the task's state. -func (tr *TaskRunner) setStateLocal(state string, event *structs.TaskEvent) { - tr.stateLock.Lock() - defer tr.stateLock.Unlock() +// updateStateImpl updates the in-memory task state and persists to disk. +func (tr *TaskRunner) updateStateImpl(state string) error { // Update the task state oldState := tr.state.State taskState := tr.state taskState.State = state - // Append the event - tr.appendEvent(event) - // Handle the state transition. switch state { case structs.TaskStateRunning: @@ -662,11 +764,7 @@ func (tr *TaskRunner) setStateLocal(state string, event *structs.TaskEvent) { } // Persist the state and event - if err := tr.stateDB.PutTaskState(tr.allocID, tr.taskName, taskState); err != nil { - // Only a warning because the next event/state-transition will - // try to persist it again. - tr.logger.Error("error persisting task state", "error", err, "event", event, "state", state) - } + return tr.stateDB.PutTaskState(tr.allocID, tr.taskName, taskState) } // EmitEvent appends a new TaskEvent to this task's TaskState. 
The actual diff --git a/client/allocrunner/taskrunner/task_runner_hooks.go b/client/allocrunner/taskrunner/task_runner_hooks.go index 4eaecc81b59..40e157a5f08 100644 --- a/client/allocrunner/taskrunner/task_runner_hooks.go +++ b/client/allocrunner/taskrunner/task_runner_hooks.go @@ -26,7 +26,6 @@ func (tr *TaskRunner) initHooks() { newLogMonHook(tr.logmonHookConfig, hookLogger), newDispatchHook(tr.Alloc(), hookLogger), newArtifactHook(tr, hookLogger), - newShutdownDelayHook(task.ShutdownDelay, hookLogger), newStatsHook(tr, tr.clientConfig.StatsCollectionInterval, hookLogger), } @@ -123,7 +122,7 @@ func (tr *TaskRunner) prestart() error { // Run the prestart hook var resp interfaces.TaskPrestartResponse - if err := pre.Prestart(tr.ctx, &req, &resp); err != nil { + if err := pre.Prestart(tr.killCtx, &req, &resp); err != nil { return structs.WrapRecoverable(fmt.Sprintf("prestart hook %q failed: %v", name, err), err) } @@ -195,7 +194,7 @@ func (tr *TaskRunner) poststart() error { TaskEnv: tr.envBuilder.Build(), } var resp interfaces.TaskPoststartResponse - if err := post.Poststart(tr.ctx, &req, &resp); err != nil { + if err := post.Poststart(tr.killCtx, &req, &resp); err != nil { merr.Errors = append(merr.Errors, fmt.Errorf("poststart hook %q failed: %v", name, err)) } @@ -237,7 +236,7 @@ func (tr *TaskRunner) exited() error { req := interfaces.TaskExitedRequest{} var resp interfaces.TaskExitedResponse - if err := post.Exited(tr.ctx, &req, &resp); err != nil { + if err := post.Exited(tr.killCtx, &req, &resp); err != nil { merr.Errors = append(merr.Errors, fmt.Errorf("exited hook %q failed: %v", name, err)) } @@ -280,7 +279,7 @@ func (tr *TaskRunner) stop() error { req := interfaces.TaskStopRequest{} var resp interfaces.TaskStopResponse - if err := post.Stop(tr.ctx, &req, &resp); err != nil { + if err := post.Stop(tr.killCtx, &req, &resp); err != nil { merr.Errors = append(merr.Errors, fmt.Errorf("stop hook %q failed: %v", name, err)) } @@ -336,7 +335,7 @@ func (tr *TaskRunner) updateHooks() { // Run the update hook var resp interfaces.TaskUpdateResponse - if err := upd.Update(tr.ctx, &req, &resp); err != nil { + if err := upd.Update(tr.killCtx, &req, &resp); err != nil { tr.logger.Error("update hook failed", "name", name, "error", err) } @@ -349,8 +348,8 @@ func (tr *TaskRunner) updateHooks() { } } -// kill is used to run the runners kill hooks. -func (tr *TaskRunner) kill() { +// killing is used to run the runners kill hooks. 
+func (tr *TaskRunner) killing() { if tr.logger.IsTrace() { start := time.Now() tr.logger.Trace("running kill hooks", "start", start) @@ -378,7 +377,7 @@ func (tr *TaskRunner) kill() { // Run the update hook req := interfaces.TaskKillRequest{} var resp interfaces.TaskKillResponse - if err := upd.Kill(context.Background(), &req, &resp); err != nil { + if err := upd.Killing(context.Background(), &req, &resp); err != nil { tr.logger.Error("kill hook failed", "name", name, "error", err) } diff --git a/client/allocrunner/taskrunner/task_runner_test.go b/client/allocrunner/taskrunner/task_runner_test.go new file mode 100644 index 00000000000..f7da0f7e368 --- /dev/null +++ b/client/allocrunner/taskrunner/task_runner_test.go @@ -0,0 +1,157 @@ +package taskrunner + +import ( + "context" + "fmt" + "path/filepath" + "testing" + "time" + + "github.com/hashicorp/nomad/client/allocdir" + "github.com/hashicorp/nomad/client/config" + consulapi "github.com/hashicorp/nomad/client/consul" + cstate "github.com/hashicorp/nomad/client/state" + "github.com/hashicorp/nomad/client/vaultclient" + "github.com/hashicorp/nomad/helper/testlog" + "github.com/hashicorp/nomad/nomad/mock" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/plugins/shared/catalog" + "github.com/hashicorp/nomad/plugins/shared/singleton" + "github.com/hashicorp/nomad/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type MockTaskStateUpdater struct { + ch chan struct{} +} + +func NewMockTaskStateUpdater() *MockTaskStateUpdater { + return &MockTaskStateUpdater{ + ch: make(chan struct{}, 1), + } +} + +func (m *MockTaskStateUpdater) TaskStateUpdated() { + select { + case m.ch <- struct{}{}: + default: + } +} + +// testTaskRunnerConfig returns a taskrunner.Config for the given alloc+task +// plus a cleanup func. +func testTaskRunnerConfig(t *testing.T, alloc *structs.Allocation, taskName string) (*Config, func()) { + logger := testlog.HCLogger(t) + pluginLoader := catalog.TestPluginLoader(t) + clientConf, cleanup := config.TestClientConfig(t) + + // Find the task + var thisTask *structs.Task + for _, tg := range alloc.Job.TaskGroups { + for _, task := range tg.Tasks { + if task.Name == taskName { + if thisTask != nil { + cleanup() + t.Fatalf("multiple tasks named %q; cannot use this helper", taskName) + } + thisTask = task + } + } + } + if thisTask == nil { + cleanup() + t.Fatalf("could not find task %q", taskName) + } + + // Create the alloc dir + task dir + allocPath := filepath.Join(clientConf.AllocDir, alloc.ID) + allocDir := allocdir.NewAllocDir(logger, allocPath) + if err := allocDir.Build(); err != nil { + cleanup() + t.Fatalf("error building alloc dir: %v", err) + } + taskDir := allocDir.NewTaskDir(taskName) + + trCleanup := func() { + if err := allocDir.Destroy(); err != nil { + t.Logf("error destroying alloc dir: %v", err) + } + cleanup() + } + + conf := &Config{ + Alloc: alloc, + ClientConfig: clientConf, + Consul: consulapi.NewMockConsulServiceClient(t, logger), + Task: thisTask, + TaskDir: taskDir, + Logger: clientConf.Logger, + Vault: vaultclient.NewMockVaultClient(), + StateDB: cstate.NoopDB{}, + StateUpdater: NewMockTaskStateUpdater(), + PluginSingletonLoader: singleton.NewSingletonLoader(logger, pluginLoader), + } + return conf, trCleanup +} + +// TestTaskRunner_Restore asserts restoring a running task does not rerun the +// task. 
+func TestTaskRunner_Restore_Running(t *testing.T) { + t.Parallel() + require := require.New(t) + + alloc := mock.BatchAlloc() + alloc.Job.TaskGroups[0].Count = 1 + task := alloc.Job.TaskGroups[0].Tasks[0] + task.Name = "testtask" + task.Driver = "mock_driver" + task.Config = map[string]interface{}{ + "run_for": 2 * time.Second, + } + conf, cleanup := testTaskRunnerConfig(t, alloc, "testtask") + conf.StateDB = cstate.NewMemDB() // "persist" state between task runners + defer cleanup() + + // Run the first TaskRunner + origTR, err := NewTaskRunner(conf) + require.NoError(err) + go origTR.Run() + defer origTR.Kill(context.Background(), structs.NewTaskEvent("cleanup")) + + // Wait for it to be running + testutil.WaitForResult(func() (bool, error) { + ts := origTR.TaskState() + return ts.State == structs.TaskStateRunning, fmt.Errorf("%v", ts.State) + }, func(err error) { + t.Fatalf("expected running; got: %v", err) + }) + + // Cause TR to exit without shutting down task + origTR.ctxCancel() + <-origTR.WaitCh() + + // Start a new TaskRunner and make sure it does not rerun the task + newTR, err := NewTaskRunner(conf) + require.NoError(err) + + // Do the Restore + require.NoError(newTR.Restore()) + + go newTR.Run() + defer newTR.Kill(context.Background(), structs.NewTaskEvent("cleanup")) + + // Wait for new task runner to exit when the process does + <-newTR.WaitCh() + + // Assert that the process was only started once + started := 0 + state := newTR.TaskState() + require.Equal(structs.TaskStateDead, state.State) + for _, ev := range state.Events { + if ev.Type == structs.TaskStarted { + started++ + } + } + assert.Equal(t, 1, started) +} diff --git a/client/config/testing.go b/client/config/testing.go index 8281938ff3b..73ab82d2d46 100644 --- a/client/config/testing.go +++ b/client/config/testing.go @@ -6,6 +6,7 @@ import ( "path/filepath" "github.com/hashicorp/nomad/helper" + "github.com/hashicorp/nomad/helper/testlog" "github.com/hashicorp/nomad/nomad/structs" "github.com/mitchellh/go-testing-interface" ) @@ -14,6 +15,7 @@ import ( // a cleanup func to remove the state and alloc dirs when finished. 
func TestClientConfig(t testing.T) (*Config, func()) { conf := DefaultConfig() + conf.Logger = testlog.HCLogger(t) // Create a tempdir to hold state and alloc subdirs parent, err := ioutil.TempDir("", "nomadtest") diff --git a/command/agent/consul/int_test.go b/command/agent/consul/int_test.go index 405a21854f5..65dab80fa97 100644 --- a/command/agent/consul/int_test.go +++ b/command/agent/consul/int_test.go @@ -20,6 +20,8 @@ import ( "github.com/hashicorp/nomad/helper/testlog" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/plugins/shared/catalog" + "github.com/hashicorp/nomad/plugins/shared/singleton" "github.com/stretchr/testify/require" ) @@ -143,16 +145,18 @@ func TestConsul_Integration(t *testing.T) { }() // Build the config + pluginLoader := catalog.TestPluginLoader(t) config := &taskrunner.Config{ - Alloc: alloc, - ClientConfig: conf, - Consul: serviceClient, - Task: task, - TaskDir: taskDir, - Logger: logger, - VaultClient: vclient, - StateDB: state.NoopDB{}, - StateUpdater: logUpdate, + Alloc: alloc, + ClientConfig: conf, + Consul: serviceClient, + Task: task, + TaskDir: taskDir, + Logger: logger, + Vault: vclient, + StateDB: state.NoopDB{}, + StateUpdater: logUpdate, + PluginSingletonLoader: singleton.NewSingletonLoader(logger, pluginLoader), } tr, err := taskrunner.NewTaskRunner(config) diff --git a/drivers/mock/driver.go b/drivers/mock/driver.go index 04a0c91cba5..734ae61f8da 100644 --- a/drivers/mock/driver.go +++ b/drivers/mock/driver.go @@ -273,9 +273,23 @@ func (d *Driver) buildFingerprint() *drivers.Fingerprint { } } -func (d *Driver) RecoverTask(*drivers.TaskHandle) error { - //TODO is there anything to do here? - return nil +func (d *Driver) RecoverTask(h *drivers.TaskHandle) error { + if h == nil { + return fmt.Errorf("handle cannot be nil") + } + + if _, ok := d.tasks.Get(h.Config.ID); ok { + d.logger.Debug("nothing to recover; task already exists", + "task_id", h.Config.ID, + "task_name", h.Config.Name, + ) + return nil + } + + // Recovering a task requires the task to be running external to the + // plugin. Since the mock_driver runs all tasks in process it cannot + // recover tasks. + return fmt.Errorf("%s cannot recover tasks", pluginName) } func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *cstructs.DriverNetwork, error) { diff --git a/drivers/qemu/driver.go b/drivers/qemu/driver.go index fde8db8e11f..9c9a3a3fe5e 100644 --- a/drivers/qemu/driver.go +++ b/drivers/qemu/driver.go @@ -244,6 +244,15 @@ func (d *Driver) RecoverTask(handle *drivers.TaskHandle) error { return fmt.Errorf("error: handle cannot be nil") } + // If already attached to handle there's nothing to recover. 
+ if _, ok := d.tasks.Get(handle.Config.ID); ok { + d.logger.Trace("nothing to recover; task already exists", + "task_id", handle.Config.ID, + "task_name", handle.Config.Name, + ) + return nil + } + var taskState TaskState if err := handle.GetDriverState(&taskState); err != nil { d.logger.Error("failed to decode taskConfig state from handle", "error", err, "task_id", handle.Config.ID) diff --git a/drivers/rawexec/driver.go b/drivers/rawexec/driver.go index 6c75a79de73..fefd3190dc9 100644 --- a/drivers/rawexec/driver.go +++ b/drivers/rawexec/driver.go @@ -242,9 +242,19 @@ func (d *Driver) buildFingerprint() *drivers.Fingerprint { func (d *Driver) RecoverTask(handle *drivers.TaskHandle) error { if handle == nil { - return fmt.Errorf("error: handle cannot be nil") + return fmt.Errorf("handle cannot be nil") } + // If already attached to handle there's nothing to recover. + if _, ok := d.tasks.Get(handle.Config.ID); ok { + d.logger.Trace("nothing to recover; task already exists", + "task_id", handle.Config.ID, + "task_name", handle.Config.Name, + ) + return nil + } + + // Handle doesn't already exist, try to reattach var taskState TaskState if err := handle.GetDriverState(&taskState); err != nil { d.logger.Error("failed to decode task state from handle", "error", err, "task_id", handle.Config.ID) @@ -261,6 +271,7 @@ func (d *Driver) RecoverTask(handle *drivers.TaskHandle) error { Reattach: plugRC, } + // Create client for reattached executor exec, pluginClient, err := utils.CreateExecutorWithConfig(pluginConfig, os.Stderr) if err != nil { d.logger.Error("failed to reattach to executor", "error", err, "task_id", handle.Config.ID) diff --git a/drivers/rkt/driver.go b/drivers/rkt/driver.go index cd5919e0629..e7f89c75830 100644 --- a/drivers/rkt/driver.go +++ b/drivers/rkt/driver.go @@ -317,6 +317,15 @@ func (d *Driver) RecoverTask(handle *drivers.TaskHandle) error { return fmt.Errorf("error: handle cannot be nil") } + // If already attached to handle there's nothing to recover. + if _, ok := d.tasks.Get(handle.Config.ID); ok { + d.logger.Trace("nothing to recover; task already exists", + "task_id", handle.Config.ID, + "task_name", handle.Config.Name, + ) + return nil + } + var taskState TaskState if err := handle.GetDriverState(&taskState); err != nil { d.logger.Error("failed to decode taskConfig state from handle", "error", err, "task_id", handle.Config.ID) diff --git a/plugins/drivers/driver.go b/plugins/drivers/driver.go index f37a10ff00d..2bb7267c4b1 100644 --- a/plugins/drivers/driver.go +++ b/plugins/drivers/driver.go @@ -46,7 +46,7 @@ type DriverPlugin interface { // DriverPlugin interface. 
type DriverSignalTaskNotSupported struct{} -func (_ DriverSignalTaskNotSupported) SignalTask(taskID, signal string) error { +func (DriverSignalTaskNotSupported) SignalTask(taskID, signal string) error { return fmt.Errorf("SignalTask is not supported by this driver") } diff --git a/plugins/drivers/plugin_test.go b/plugins/drivers/plugin_test.go index ce15d074b0a..3409124e8cb 100644 --- a/plugins/drivers/plugin_test.go +++ b/plugins/drivers/plugin_test.go @@ -103,7 +103,7 @@ func TestBaseDriver_RecoverTask(t *testing.T) { defer harness.Kill() handle := &TaskHandle{ - driverState: buf.Bytes(), + DriverState: buf.Bytes(), } err := harness.RecoverTask(handle) require.NoError(err) diff --git a/plugins/drivers/task_handle.go b/plugins/drivers/task_handle.go index 749c1031485..ea538c5778f 100644 --- a/plugins/drivers/task_handle.go +++ b/plugins/drivers/task_handle.go @@ -11,7 +11,7 @@ type TaskHandle struct { Driver string Config *TaskConfig State TaskState - driverState []byte + DriverState []byte } func NewTaskHandle(driver string) *TaskHandle { @@ -19,12 +19,12 @@ func NewTaskHandle(driver string) *TaskHandle { } func (h *TaskHandle) SetDriverState(v interface{}) error { - h.driverState = []byte{} - return base.MsgPackEncode(&h.driverState, v) + h.DriverState = []byte{} + return base.MsgPackEncode(&h.DriverState, v) } func (h *TaskHandle) GetDriverState(v interface{}) error { - return base.MsgPackDecode(h.driverState, v) + return base.MsgPackDecode(h.DriverState, v) } @@ -34,7 +34,10 @@ func (h *TaskHandle) Copy() *TaskHandle { } handle := new(TaskHandle) - *handle = *h + handle.Driver = h.Driver handle.Config = h.Config.Copy() + handle.State = h.State + handle.DriverState = make([]byte, len(h.DriverState)) + copy(handle.DriverState, h.DriverState) return handle } diff --git a/plugins/drivers/utils.go b/plugins/drivers/utils.go index db37561879b..4caf3da7f2c 100644 --- a/plugins/drivers/utils.go +++ b/plugins/drivers/utils.go @@ -194,7 +194,7 @@ func taskHandleFromProto(pb *proto.TaskHandle) *TaskHandle { return &TaskHandle{ Config: taskConfigFromProto(pb.Config), State: taskStateFromProtoMap[pb.State], - driverState: pb.DriverState, + DriverState: pb.DriverState, } } @@ -202,7 +202,7 @@ func taskHandleToProto(handle *TaskHandle) *proto.TaskHandle { return &proto.TaskHandle{ Config: taskConfigToProto(handle.Config), State: taskStateToProtoMap[handle.State], - DriverState: handle.driverState, + DriverState: handle.DriverState, } }
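// --- Reviewer sketch (not part of the patch) ---------------------------------
// The killCtx/ctx split introduced in task_runner.go above can be hard to
// follow in diff form. Below is a minimal, self-contained sketch of the
// pattern using only the standard library and a hypothetical `runner` type
// (not the real TaskRunner): canceling killCtx kills the task but still lets
// cleanup and final state updates run, while canceling ctx makes the runner
// exit immediately without touching task state (e.g. during agent shutdown,
// or as TestTaskRunner_Restore_Running does via origTR.ctxCancel()).
package main

import (
	"context"
	"fmt"
	"time"
)

type runner struct {
	killCtx       context.Context
	killCtxCancel context.CancelFunc
	ctx           context.Context
	ctxCancel     context.CancelFunc
}

func newRunner() *runner {
	killCtx, killCancel := context.WithCancel(context.Background())
	ctx, cancel := context.WithCancel(context.Background())
	return &runner{killCtx: killCtx, killCtxCancel: killCancel, ctx: ctx, ctxCancel: cancel}
}

func (r *runner) run(done chan<- struct{}) {
	defer close(done)
MAIN:
	for {
		select {
		case <-r.killCtx.Done():
			// Task killed: fall through to cleanup and final state updates.
			break MAIN
		case <-r.ctx.Done():
			// Runner told to exit immediately; leave task state untouched.
			return
		default:
		}

		// ... start the task and wait on it here ...
		time.Sleep(50 * time.Millisecond)
	}

	fmt.Println("task killed; running stop hooks and marking state dead")
}

func main() {
	r := newRunner()
	done := make(chan struct{})
	go r.run(done)

	r.killCtxCancel() // kill the task; run() still performs its cleanup path
	<-done
}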