diff --git a/api/tasks.go b/api/tasks.go
index 578188cc226..f9ad7856bb4 100644
--- a/api/tasks.go
+++ b/api/tasks.go
@@ -411,6 +411,7 @@ type TaskGroup struct {
 	Networks      []*NetworkResource
 	Meta          map[string]string
 	Services      []*Service
+	ShutdownDelay *time.Duration `mapstructure:"shutdown_delay"`
 }
 
 // NewTaskGroup creates a new TaskGroup.
diff --git a/client/allocrunner/alloc_runner.go b/client/allocrunner/alloc_runner.go
index a8291cc7199..eee412cc725 100644
--- a/client/allocrunner/alloc_runner.go
+++ b/client/allocrunner/alloc_runner.go
@@ -500,6 +500,9 @@ func (ar *allocRunner) killTasks() map[string]*structs.TaskState {
 	var mu sync.Mutex
 	states := make(map[string]*structs.TaskState, len(ar.tasks))
 
+	// Run alloc pre-kill hooks
+	ar.preKillHooks()
+
 	// Kill leader first, synchronously
 	for name, tr := range ar.tasks {
 		if !tr.IsLeader() {
diff --git a/client/allocrunner/alloc_runner_hooks.go b/client/allocrunner/alloc_runner_hooks.go
index 63dc78f49f6..c55ed384e4f 100644
--- a/client/allocrunner/alloc_runner_hooks.go
+++ b/client/allocrunner/alloc_runner_hooks.go
@@ -295,6 +295,30 @@ func (ar *allocRunner) destroy() error {
 	return merr.ErrorOrNil()
 }
 
+// preKillHooks runs any registered pre-kill hooks before the allocation's tasks are killed.
+func (ar *allocRunner) preKillHooks() {
+	for _, hook := range ar.runnerHooks {
+		pre, ok := hook.(interfaces.RunnerPreKillHook)
+		if !ok {
+			continue
+		}
+
+		name := pre.Name()
+		var start time.Time
+		if ar.logger.IsTrace() {
+			start = time.Now()
+			ar.logger.Trace("running alloc pre shutdown hook", "name", name, "start", start)
+		}
+
+		pre.PreKill()
+
+		if ar.logger.IsTrace() {
+			end := time.Now()
+			ar.logger.Trace("finished alloc pre shutdown hook", "name", name, "end", end, "duration", end.Sub(start))
+		}
+	}
+}
+
 // shutdownHooks calls graceful shutdown hooks for when the agent is exiting.
 func (ar *allocRunner) shutdownHooks() {
 	for _, hook := range ar.runnerHooks {
diff --git a/client/allocrunner/alloc_runner_test.go b/client/allocrunner/alloc_runner_test.go
index f4ee44a65be..8aaba67b4d5 100644
--- a/client/allocrunner/alloc_runner_test.go
+++ b/client/allocrunner/alloc_runner_test.go
@@ -141,6 +141,101 @@ func TestAllocRunner_TaskLeader_KillTG(t *testing.T) {
 	})
 }
 
+func TestAllocRunner_TaskGroup_ShutdownDelay(t *testing.T) {
+	t.Parallel()
+	shutdownDelay := 1 * time.Second
+
+	alloc := mock.Alloc()
+	tr := alloc.AllocatedResources.Tasks[alloc.Job.TaskGroups[0].Tasks[0].Name]
+	alloc.Job.TaskGroups[0].RestartPolicy.Attempts = 0
+
+	// Create two tasks in the task group: a follower and a leader
+	task := alloc.Job.TaskGroups[0].Tasks[0]
+	task.Name = "follower1"
+	task.Driver = "mock_driver"
+	task.Config = map[string]interface{}{
+		"run_for": "10s",
+	}
+
+	task2 := alloc.Job.TaskGroups[0].Tasks[0].Copy()
+	task2.Name = "leader"
+	task2.Driver = "mock_driver"
+	task2.Leader = true
+	task2.Config = map[string]interface{}{
+		"run_for": "10s",
+	}
+
+	alloc.Job.TaskGroups[0].Tasks = append(alloc.Job.TaskGroups[0].Tasks, task2)
+	alloc.AllocatedResources.Tasks[task.Name] = tr
+	alloc.AllocatedResources.Tasks[task2.Name] = tr
+
+	// Set a shutdown delay
+	alloc.Job.TaskGroups[0].ShutdownDelay = &shutdownDelay
+
+	conf, cleanup := testAllocRunnerConfig(t, alloc)
+	defer cleanup()
+	ar, err := NewAllocRunner(conf)
+	require.NoError(t, err)
+	defer destroy(ar)
+	go ar.Run()
+
+	// Wait for tasks to start
+	upd := conf.StateUpdater.(*MockStateUpdater)
+	last := upd.Last()
+	testutil.WaitForResult(func() (bool, error) {
+		last = upd.Last()
+		if last == nil {
+			return false, fmt.Errorf("No updates")
+		}
+		if n := len(last.TaskStates); n != 2 {
+			return false,
fmt.Errorf("Not enough task states (want: 2; found %d)", n) + } + for name, state := range last.TaskStates { + if state.State != structs.TaskStateRunning { + return false, fmt.Errorf("Task %q is not running yet (it's %q)", name, state.State) + } + } + return true, nil + }, func(err error) { + t.Fatalf("err: %v", err) + }) + + // Reset updates + upd.Reset() + + // Stop alloc + now := time.Now() + update := alloc.Copy() + update.DesiredStatus = structs.AllocDesiredStatusStop + ar.Update(update) + + // Wait for tasks to stop + testutil.WaitForResult(func() (bool, error) { + last := upd.Last() + if last == nil { + return false, fmt.Errorf("No updates") + } + + fin := last.TaskStates["leader"].FinishedAt + + if fin.IsZero() { + return false, nil + } + + return true, nil + }, func(err error) { + last := upd.Last() + for name, state := range last.TaskStates { + t.Logf("%s: %s", name, state.State) + } + t.Fatalf("err: %v", err) + }) + + last = upd.Last() + require.Greater(t, last.TaskStates["leader"].FinishedAt.UnixNano(), now.Add(shutdownDelay).UnixNano()) + require.Greater(t, last.TaskStates["follower1"].FinishedAt.UnixNano(), now.Add(shutdownDelay).UnixNano()) +} + // TestAllocRunner_TaskLeader_StopTG asserts that when stopping an alloc with a // leader the leader is stopped before other tasks. func TestAllocRunner_TaskLeader_StopTG(t *testing.T) { diff --git a/client/allocrunner/groupservice_hook.go b/client/allocrunner/groupservice_hook.go index 1660fc738f5..a02cc9d5b4f 100644 --- a/client/allocrunner/groupservice_hook.go +++ b/client/allocrunner/groupservice_hook.go @@ -2,6 +2,7 @@ package allocrunner import ( "sync" + "time" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/nomad/client/allocrunner/interfaces" @@ -20,6 +21,8 @@ type groupServiceHook struct { restarter agentconsul.WorkloadRestarter consulClient consul.ConsulServiceAPI prerun bool + delay time.Duration + deregistered bool logger log.Logger @@ -43,12 +46,20 @@ type groupServiceHookConfig struct { } func newGroupServiceHook(cfg groupServiceHookConfig) *groupServiceHook { + var shutdownDelay time.Duration + tg := cfg.alloc.Job.LookupTaskGroup(cfg.alloc.TaskGroup) + + if tg != nil && tg.ShutdownDelay != nil { + shutdownDelay = *tg.ShutdownDelay + } + h := &groupServiceHook{ allocID: cfg.alloc.ID, group: cfg.alloc.TaskGroup, restarter: cfg.restarter, consulClient: cfg.consul, taskEnvBuilder: cfg.taskEnvBuilder, + delay: shutdownDelay, } h.logger = cfg.logger.Named(h.Name()) h.services = cfg.alloc.Job.LookupTaskGroup(h.group).Services @@ -117,10 +128,29 @@ func (h *groupServiceHook) Update(req *interfaces.RunnerUpdateRequest) error { return h.consulClient.UpdateWorkload(oldWorkloadServices, newWorkloadServices) } -func (h *groupServiceHook) Postrun() error { +func (h *groupServiceHook) PreKill() { h.mu.Lock() defer h.mu.Unlock() + + // If we have a shutdown delay deregister + // group services and then wait + // before continuing to kill tasks h.deregister() + h.deregistered = true + + h.logger.Debug("waiting before removing group service", "shutdown_delay", h.delay) + select { + case <-time.After(h.delay): + } +} + +func (h *groupServiceHook) Postrun() error { + h.mu.Lock() + defer h.mu.Unlock() + + if !h.deregistered { + h.deregister() + } return nil } diff --git a/client/allocrunner/groupservice_hook_test.go b/client/allocrunner/groupservice_hook_test.go index afd8cb0d118..f479a27b4a0 100644 --- a/client/allocrunner/groupservice_hook_test.go +++ b/client/allocrunner/groupservice_hook_test.go @@ -20,6 
 var _ interfaces.RunnerPrerunHook = (*groupServiceHook)(nil)
 var _ interfaces.RunnerUpdateHook = (*groupServiceHook)(nil)
 var _ interfaces.RunnerPostrunHook = (*groupServiceHook)(nil)
+var _ interfaces.RunnerPreKillHook = (*groupServiceHook)(nil)
 
 // TestGroupServiceHook_NoGroupServices asserts calling group service hooks
 // without group services does not error.
diff --git a/client/allocrunner/interfaces/runner_lifecycle.go b/client/allocrunner/interfaces/runner_lifecycle.go
index 8bbaba4af37..f0b2e6c74fa 100644
--- a/client/allocrunner/interfaces/runner_lifecycle.go
+++ b/client/allocrunner/interfaces/runner_lifecycle.go
@@ -53,3 +53,10 @@ type ShutdownHook interface {
 
 	Shutdown()
 }
+
+// RunnerPreKillHook is a RunnerHook whose PreKill method is called before an allocation's tasks are killed.
+type RunnerPreKillHook interface {
+	RunnerHook
+
+	PreKill()
+}
diff --git a/command/agent/job_endpoint.go b/command/agent/job_endpoint.go
index 710c94ec31d..3814c921a0c 100644
--- a/command/agent/job_endpoint.go
+++ b/command/agent/job_endpoint.go
@@ -696,6 +696,7 @@ func ApiTgToStructsTG(taskGroup *api.TaskGroup, tg *structs.TaskGroup) {
 	tg.Name = *taskGroup.Name
 	tg.Count = *taskGroup.Count
 	tg.Meta = taskGroup.Meta
+	tg.ShutdownDelay = taskGroup.ShutdownDelay
 	tg.Constraints = ApiConstraintsToStructs(taskGroup.Constraints)
 	tg.Affinities = ApiAffinitiesToStructs(taskGroup.Affinities)
 	tg.Networks = ApiNetworkResourceToStructs(taskGroup.Networks)
diff --git a/jobspec/parse_group.go b/jobspec/parse_group.go
index 39ebb292c4a..062c9607444 100644
--- a/jobspec/parse_group.go
+++ b/jobspec/parse_group.go
@@ -51,6 +51,7 @@ func parseGroups(result *api.Job, list *ast.ObjectList) error {
 			"vault",
 			"migrate",
 			"spread",
+			"shutdown_delay",
 			"network",
 			"service",
 			"volume",
@@ -63,6 +64,7 @@ func parseGroups(result *api.Job, list *ast.ObjectList) error {
 		if err := hcl.DecodeObject(&m, item.Val); err != nil {
 			return err
 		}
+
 		delete(m, "constraint")
 		delete(m, "affinity")
 		delete(m, "meta")
@@ -80,7 +82,18 @@ func parseGroups(result *api.Job, list *ast.ObjectList) error {
 		// Build the group with the basic decode
 		var g api.TaskGroup
 		g.Name = helper.StringToPtr(n)
-		if err := mapstructure.WeakDecode(m, &g); err != nil {
+		// Use a decoder with a string-to-duration hook so fields like
+		// shutdown_delay can be written as "14s" in the jobspec.
+		dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
+			DecodeHook:       mapstructure.StringToTimeDurationHookFunc(),
+			WeaklyTypedInput: true,
+			Result:           &g,
+		})
+
+		if err != nil {
+			return err
+		}
+		if err := dec.Decode(m); err != nil {
 			return err
 		}
 
@@ -201,7 +214,6 @@ func parseGroups(result *api.Job, list *ast.ObjectList) error {
 				return multierror.Prefix(err, fmt.Sprintf("'%s',", n))
 			}
 		}
-
 		collection = append(collection, &g)
 	}
 
diff --git a/jobspec/parse_test.go b/jobspec/parse_test.go
index 407666bfd4a..b95cdc93e12 100644
--- a/jobspec/parse_test.go
+++ b/jobspec/parse_test.go
@@ -926,8 +926,9 @@ func TestParse(t *testing.T) {
 				Datacenters: []string{"dc1"},
 				TaskGroups: []*api.TaskGroup{
 					{
-						Name:  helper.StringToPtr("bar"),
-						Count: helper.IntToPtr(3),
+						Name:          helper.StringToPtr("bar"),
+						ShutdownDelay: helper.TimeToPtr(14 * time.Second),
+						Count:         helper.IntToPtr(3),
 						Networks: []*api.NetworkResource{
 							{
 								Mode: "bridge",
diff --git a/jobspec/test-fixtures/tg-network.hcl b/jobspec/test-fixtures/tg-network.hcl
index dadf7eccc82..538f49f3acf 100644
--- a/jobspec/test-fixtures/tg-network.hcl
+++ b/jobspec/test-fixtures/tg-network.hcl
@@ -2,7 +2,8 @@ job "foo" {
   datacenters = ["dc1"]
 
   group "bar" {
-    count = 3
+    count          = 3
+    shutdown_delay = "14s"
 
     network {
       mode = "bridge"
diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go
index cf3bb544383..80253a36352 100644
--- a/nomad/structs/structs.go
+++ b/nomad/structs/structs.go
@@ -3547,6 +3547,10 @@ func (j *Job) Validate() error {
 			taskGroups[tg.Name] = idx
 		}
 
+		if tg.ShutdownDelay != nil && *tg.ShutdownDelay < 0 {
+			mErr.Errors = append(mErr.Errors, fmt.Errorf("ShutdownDelay must be a positive value"))
+		}
+
 		if j.Type == "system" && tg.Count > 1 {
 			mErr.Errors = append(mErr.Errors,
 				fmt.Errorf("Job task group %s has count %d. Count cannot exceed 1 with system scheduler",
@@ -4736,6 +4740,10 @@ type TaskGroup struct {
 
 	// Volumes is a map of volumes that have been requested by the task group.
 	Volumes map[string]*VolumeRequest
+
+	// ShutdownDelay is the amount of time to wait after deregistering the
+	// group's services before killing the group's tasks.
+	ShutdownDelay *time.Duration
 }
 
 func (tg *TaskGroup) Copy() *TaskGroup {
@@ -4782,6 +4790,10 @@ func (tg *TaskGroup) Copy() *TaskGroup {
 		}
 	}
 
+	if tg.ShutdownDelay != nil {
+		ntg.ShutdownDelay = helper.TimeToPtr(*tg.ShutdownDelay)
+	}
+
 	return ntg
 }