diff --git a/CHANGELOG.md b/CHANGELOG.md index b11a4ce6b18..a881fef4972 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ IMPROVEMENTS: * api: Metrics endpoint exposes Prometheus formatted metrics [GH-3171] + * discovery: Allow restarting unhealthy tasks with `check_restart` [GH-3105] * telemetry: Add support for tagged metrics for Nomad clients [GH-3147] * telemetry: Add basic Prometheus configuration for a Nomad cluster [GH-3186] diff --git a/api/tasks.go b/api/tasks.go index 76c1be65889..a3d10831e82 100644 --- a/api/tasks.go +++ b/api/tasks.go @@ -79,6 +79,71 @@ func (r *RestartPolicy) Merge(rp *RestartPolicy) { } } +// CheckRestart describes if and when a task should be restarted based on +// failing health checks. +type CheckRestart struct { + Limit int `mapstructure:"limit"` + Grace *time.Duration `mapstructure:"grace_period"` + IgnoreWarnings bool `mapstructure:"ignore_warnings"` +} + +// Canonicalize CheckRestart fields if not nil. +func (c *CheckRestart) Canonicalize() { + if c == nil { + return + } + + if c.Grace == nil { + c.Grace = helper.TimeToPtr(1 * time.Second) + } +} + +// Copy returns a copy of CheckRestart or nil if unset. +func (c *CheckRestart) Copy() *CheckRestart { + if c == nil { + return nil + } + + nc := new(CheckRestart) + nc.Limit = c.Limit + if c.Grace != nil { + g := *c.Grace + nc.Grace = &g + } + nc.IgnoreWarnings = c.IgnoreWarnings + return nc +} + +// Merge values from other CheckRestart over default values on this +// CheckRestart and return merged copy. +func (c *CheckRestart) Merge(o *CheckRestart) *CheckRestart { + if c == nil { + // Just return other + return o + } + + nc := c.Copy() + + if o == nil { + // Nothing to merge + return nc + } + + if nc.Limit == 0 { + nc.Limit = o.Limit + } + + if nc.Grace == nil { + nc.Grace = o.Grace + } + + if nc.IgnoreWarnings { + nc.IgnoreWarnings = o.IgnoreWarnings + } + + return nc +} + // The ServiceCheck data model represents the consul health check that // Nomad registers for a Task type ServiceCheck struct { @@ -96,16 +161,18 @@ type ServiceCheck struct { TLSSkipVerify bool `mapstructure:"tls_skip_verify"` Header map[string][]string Method string + CheckRestart *CheckRestart `mapstructure:"check_restart"` } // The Service model represents a Consul service definition type Service struct { - Id string - Name string - Tags []string - PortLabel string `mapstructure:"port"` - AddressMode string `mapstructure:"address_mode"` - Checks []ServiceCheck + Id string + Name string + Tags []string + PortLabel string `mapstructure:"port"` + AddressMode string `mapstructure:"address_mode"` + Checks []ServiceCheck + CheckRestart *CheckRestart `mapstructure:"check_restart"` } func (s *Service) Canonicalize(t *Task, tg *TaskGroup, job *Job) { @@ -117,6 +184,15 @@ func (s *Service) Canonicalize(t *Task, tg *TaskGroup, job *Job) { if s.AddressMode == "" { s.AddressMode = "auto" } + + s.CheckRestart.Canonicalize() + + // Canonicallize CheckRestart on Checks and merge Service.CheckRestart + // into each check. + for _, c := range s.Checks { + c.CheckRestart.Canonicalize() + c.CheckRestart = c.CheckRestart.Merge(s.CheckRestart) + } } // EphemeralDisk is an ephemeral disk object diff --git a/client/alloc_runner.go b/client/alloc_runner.go index d6486734e54..c4725c6e4fb 100644 --- a/client/alloc_runner.go +++ b/client/alloc_runner.go @@ -336,7 +336,8 @@ func (r *AllocRunner) RestoreState() error { // Restart task runner if RestoreState gave a reason if restartReason != "" { r.logger.Printf("[INFO] client: restarting alloc %s task %s: %v", r.allocID, name, restartReason) - tr.Restart("upgrade", restartReason) + const failure = false + tr.Restart("upgrade", restartReason, failure) } } else { tr.Destroy(taskDestroyEvent) diff --git a/client/consul.go b/client/consul.go index 89666e41e45..02e40ef0f09 100644 --- a/client/consul.go +++ b/client/consul.go @@ -10,8 +10,8 @@ import ( // ConsulServiceAPI is the interface the Nomad Client uses to register and // remove services and checks from Consul. type ConsulServiceAPI interface { - RegisterTask(allocID string, task *structs.Task, exec driver.ScriptExecutor, net *cstructs.DriverNetwork) error + RegisterTask(allocID string, task *structs.Task, restarter consul.TaskRestarter, exec driver.ScriptExecutor, net *cstructs.DriverNetwork) error RemoveTask(allocID string, task *structs.Task) - UpdateTask(allocID string, existing, newTask *structs.Task, exec driver.ScriptExecutor, net *cstructs.DriverNetwork) error + UpdateTask(allocID string, existing, newTask *structs.Task, restart consul.TaskRestarter, exec driver.ScriptExecutor, net *cstructs.DriverNetwork) error AllocRegistrations(allocID string) (*consul.AllocRegistration, error) } diff --git a/client/consul_template.go b/client/consul_template.go index 2e1b1ac60e1..2f7a629357f 100644 --- a/client/consul_template.go +++ b/client/consul_template.go @@ -49,7 +49,7 @@ var ( // TaskHooks is an interface which provides hooks into the tasks life-cycle type TaskHooks interface { // Restart is used to restart the task - Restart(source, reason string) + Restart(source, reason string, failure bool) // Signal is used to signal the task Signal(source, reason string, s os.Signal) error @@ -439,7 +439,8 @@ func (tm *TaskTemplateManager) handleTemplateRerenders(allRenderedTime time.Time } if restart { - tm.config.Hooks.Restart(consulTemplateSourceName, "template with change_mode restart re-rendered") + const failure = false + tm.config.Hooks.Restart(consulTemplateSourceName, "template with change_mode restart re-rendered", failure) } else if len(signals) != 0 { var mErr multierror.Error for signal := range signals { diff --git a/client/consul_template_test.go b/client/consul_template_test.go index 88ee17b0e50..f8368520407 100644 --- a/client/consul_template_test.go +++ b/client/consul_template_test.go @@ -57,7 +57,7 @@ func NewMockTaskHooks() *MockTaskHooks { EmitEventCh: make(chan struct{}, 1), } } -func (m *MockTaskHooks) Restart(source, reason string) { +func (m *MockTaskHooks) Restart(source, reason string, failure bool) { m.Restarts++ select { case m.RestartCh <- struct{}{}: diff --git a/client/consul_test.go b/client/consul_test.go index 10d1ebe10d0..8703cdd215a 100644 --- a/client/consul_test.go +++ b/client/consul_test.go @@ -60,7 +60,7 @@ func newMockConsulServiceClient() *mockConsulServiceClient { return &m } -func (m *mockConsulServiceClient) UpdateTask(allocID string, old, new *structs.Task, exec driver.ScriptExecutor, net *cstructs.DriverNetwork) error { +func (m *mockConsulServiceClient) UpdateTask(allocID string, old, new *structs.Task, restarter consul.TaskRestarter, exec driver.ScriptExecutor, net *cstructs.DriverNetwork) error { m.mu.Lock() defer m.mu.Unlock() m.logger.Printf("[TEST] mock_consul: UpdateTask(%q, %v, %v, %T, %x)", allocID, old, new, exec, net.Hash()) @@ -68,7 +68,7 @@ func (m *mockConsulServiceClient) UpdateTask(allocID string, old, new *structs.T return nil } -func (m *mockConsulServiceClient) RegisterTask(allocID string, task *structs.Task, exec driver.ScriptExecutor, net *cstructs.DriverNetwork) error { +func (m *mockConsulServiceClient) RegisterTask(allocID string, task *structs.Task, restarter consul.TaskRestarter, exec driver.ScriptExecutor, net *cstructs.DriverNetwork) error { m.mu.Lock() defer m.mu.Unlock() m.logger.Printf("[TEST] mock_consul: RegisterTask(%q, %q, %T, %x)", allocID, task.Name, exec, net.Hash()) diff --git a/client/restarts.go b/client/restarts.go index b6e49e31c70..c403b6f05d8 100644 --- a/client/restarts.go +++ b/client/restarts.go @@ -37,6 +37,7 @@ type RestartTracker struct { waitRes *dstructs.WaitResult startErr error restartTriggered bool // Whether the task has been signalled to be restarted + failure bool // Whether a failure triggered the restart count int // Current number of attempts. onSuccess bool // Whether to restart on successful exit code. startTime time.Time // When the interval began @@ -59,6 +60,7 @@ func (r *RestartTracker) SetStartError(err error) *RestartTracker { r.lock.Lock() defer r.lock.Unlock() r.startErr = err + r.failure = true return r } @@ -67,15 +69,22 @@ func (r *RestartTracker) SetWaitResult(res *dstructs.WaitResult) *RestartTracker r.lock.Lock() defer r.lock.Unlock() r.waitRes = res + r.failure = true return r } // SetRestartTriggered is used to mark that the task has been signalled to be -// restarted -func (r *RestartTracker) SetRestartTriggered() *RestartTracker { +// restarted. Setting the failure to true restarts according to the restart +// policy. When failure is false the task is restarted without considering the +// restart policy. +func (r *RestartTracker) SetRestartTriggered(failure bool) *RestartTracker { r.lock.Lock() defer r.lock.Unlock() - r.restartTriggered = true + if failure { + r.failure = true + } else { + r.restartTriggered = true + } return r } @@ -106,6 +115,7 @@ func (r *RestartTracker) GetState() (string, time.Duration) { r.startErr = nil r.waitRes = nil r.restartTriggered = false + r.failure = false }() // Hot path if a restart was triggered @@ -134,52 +144,29 @@ func (r *RestartTracker) GetState() (string, time.Duration) { r.startTime = now } - if r.startErr != nil { - return r.handleStartError() - } else if r.waitRes != nil { - return r.handleWaitResult() + // Handle restarts due to failures + if !r.failure { + return "", 0 } - return "", 0 -} - -// handleStartError returns the new state and potential wait duration for -// restarting the task after it was not successfully started. On start errors, -// the restart policy is always treated as fail mode to ensure we don't -// infinitely try to start a task. -func (r *RestartTracker) handleStartError() (string, time.Duration) { - // If the error is not recoverable, do not restart. - if !structs.IsRecoverable(r.startErr) { - r.reason = ReasonUnrecoverableErrror - return structs.TaskNotRestarting, 0 - } - - if r.count > r.policy.Attempts { - if r.policy.Mode == structs.RestartPolicyModeFail { - r.reason = fmt.Sprintf( - `Exceeded allowed attempts %d in interval %v and mode is "fail"`, - r.policy.Attempts, r.policy.Interval) + if r.startErr != nil { + // If the error is not recoverable, do not restart. + if !structs.IsRecoverable(r.startErr) { + r.reason = ReasonUnrecoverableErrror return structs.TaskNotRestarting, 0 - } else { - r.reason = ReasonDelay - return structs.TaskRestarting, r.getDelay() + } + } else if r.waitRes != nil { + // If the task started successfully and restart on success isn't specified, + // don't restart but don't mark as failed. + if r.waitRes.Successful() && !r.onSuccess { + r.reason = "Restart unnecessary as task terminated successfully" + return structs.TaskTerminated, 0 } } - r.reason = ReasonWithinPolicy - return structs.TaskRestarting, r.jitter() -} - -// handleWaitResult returns the new state and potential wait duration for -// restarting the task after it has exited. -func (r *RestartTracker) handleWaitResult() (string, time.Duration) { - // If the task started successfully and restart on success isn't specified, - // don't restart but don't mark as failed. - if r.waitRes.Successful() && !r.onSuccess { - r.reason = "Restart unnecessary as task terminated successfully" - return structs.TaskTerminated, 0 - } - + // If this task has been restarted due to failures more times + // than the restart policy allows within an interval fail + // according to the restart policy's mode. if r.count > r.policy.Attempts { if r.policy.Mode == structs.RestartPolicyModeFail { r.reason = fmt.Sprintf( diff --git a/client/restarts_test.go b/client/restarts_test.go index 851052576e6..b0cad5b1a3c 100644 --- a/client/restarts_test.go +++ b/client/restarts_test.go @@ -99,11 +99,24 @@ func TestClient_RestartTracker_RestartTriggered(t *testing.T) { p := testPolicy(true, structs.RestartPolicyModeFail) p.Attempts = 0 rt := newRestartTracker(p, structs.JobTypeService) - if state, when := rt.SetRestartTriggered().GetState(); state != structs.TaskRestarting && when != 0 { + if state, when := rt.SetRestartTriggered(false).GetState(); state != structs.TaskRestarting && when != 0 { t.Fatalf("expect restart immediately, got %v %v", state, when) } } +func TestClient_RestartTracker_RestartTriggered_Failure(t *testing.T) { + t.Parallel() + p := testPolicy(true, structs.RestartPolicyModeFail) + p.Attempts = 1 + rt := newRestartTracker(p, structs.JobTypeService) + if state, when := rt.SetRestartTriggered(true).GetState(); state != structs.TaskRestarting || when == 0 { + t.Fatalf("expect restart got %v %v", state, when) + } + if state, when := rt.SetRestartTriggered(true).GetState(); state != structs.TaskNotRestarting || when != 0 { + t.Fatalf("expect failed got %v %v", state, when) + } +} + func TestClient_RestartTracker_StartError_Recoverable_Fail(t *testing.T) { t.Parallel() p := testPolicy(true, structs.RestartPolicyModeFail) diff --git a/client/task_runner.go b/client/task_runner.go index cd5afbd9197..a5d96726e42 100644 --- a/client/task_runner.go +++ b/client/task_runner.go @@ -65,13 +65,29 @@ var ( taskRunnerStateAllKey = []byte("simple-all") ) +// taskRestartEvent wraps a TaskEvent with additional metadata to control +// restart behavior. +type taskRestartEvent struct { + // taskEvent to report + taskEvent *structs.TaskEvent + + // if false, don't count against restart count + failure bool +} + +func newTaskRestartEvent(reason string, failure bool) *taskRestartEvent { + return &taskRestartEvent{ + taskEvent: structs.NewTaskEvent(structs.TaskRestartSignal).SetRestartReason(reason), + failure: failure, + } +} + // TaskRunner is used to wrap a task within an allocation and provide the execution context. type TaskRunner struct { stateDB *bolt.DB config *config.Config updater TaskStateUpdater logger *log.Logger - alloc *structs.Allocation restartTracker *RestartTracker consul ConsulServiceAPI @@ -82,6 +98,7 @@ type TaskRunner struct { resourceUsage *cstructs.TaskResourceUsage resourceUsageLock sync.RWMutex + alloc *structs.Allocation task *structs.Task taskDir *allocdir.TaskDir @@ -139,7 +156,7 @@ type TaskRunner struct { unblockLock sync.Mutex // restartCh is used to restart a task - restartCh chan *structs.TaskEvent + restartCh chan *taskRestartEvent // signalCh is used to send a signal to a task signalCh chan SignalEvent @@ -247,7 +264,7 @@ func NewTaskRunner(logger *log.Logger, config *config.Config, waitCh: make(chan struct{}), startCh: make(chan struct{}, 1), unblockCh: make(chan struct{}), - restartCh: make(chan *structs.TaskEvent), + restartCh: make(chan *taskRestartEvent), signalCh: make(chan SignalEvent), } @@ -772,7 +789,8 @@ OUTER: return } case structs.VaultChangeModeRestart: - r.Restart("vault", "new Vault token acquired") + const noFailure = false + r.Restart("vault", "new Vault token acquired", noFailure) case structs.VaultChangeModeNoop: fallthrough default: @@ -1137,7 +1155,7 @@ func (r *TaskRunner) run() { res := r.handle.Signal(se.s) se.result <- res - case event := <-r.restartCh: + case restartEvent := <-r.restartCh: r.runningLock.Lock() running := r.running r.runningLock.Unlock() @@ -1147,8 +1165,8 @@ func (r *TaskRunner) run() { continue } - r.logger.Printf("[DEBUG] client: restarting %s: %v", common, event.RestartReason) - r.setState(structs.TaskStateRunning, event, false) + r.logger.Printf("[DEBUG] client: restarting %s: %v", common, restartEvent.taskEvent.RestartReason) + r.setState(structs.TaskStateRunning, restartEvent.taskEvent, false) r.killTask(nil) close(stopCollection) @@ -1157,9 +1175,7 @@ func (r *TaskRunner) run() { <-handleWaitCh } - // Since the restart isn't from a failure, restart immediately - // and don't count against the restart policy - r.restartTracker.SetRestartTriggered() + r.restartTracker.SetRestartTriggered(restartEvent.failure) break WAIT case <-r.destroyCh: @@ -1439,7 +1455,7 @@ func (r *TaskRunner) registerServices(d driver.Driver, h driver.DriverHandle, n exec = h } interpolatedTask := interpolateServices(r.envBuilder.Build(), r.task) - return r.consul.RegisterTask(r.alloc.ID, interpolatedTask, exec, n) + return r.consul.RegisterTask(r.alloc.ID, interpolatedTask, r, exec, n) } // interpolateServices interpolates tags in a service and checks with values from the @@ -1584,6 +1600,7 @@ func (r *TaskRunner) handleUpdate(update *structs.Allocation) error { for _, t := range tg.Tasks { if t.Name == r.task.Name { updatedTask = t.Copy() + break } } if updatedTask == nil { @@ -1641,7 +1658,7 @@ func (r *TaskRunner) updateServices(d driver.Driver, h driver.ScriptExecutor, ol r.driverNetLock.Lock() net := r.driverNet.Copy() r.driverNetLock.Unlock() - return r.consul.UpdateTask(r.alloc.ID, oldInterpolatedTask, newInterpolatedTask, exec, net) + return r.consul.UpdateTask(r.alloc.ID, oldInterpolatedTask, newInterpolatedTask, r, exec, net) } // handleDestroy kills the task handle. In the case that killing fails, @@ -1669,10 +1686,10 @@ func (r *TaskRunner) handleDestroy(handle driver.DriverHandle) (destroyed bool, return } -// Restart will restart the task -func (r *TaskRunner) Restart(source, reason string) { +// Restart will restart the task. +func (r *TaskRunner) Restart(source, reason string, failure bool) { reasonStr := fmt.Sprintf("%s: %s", source, reason) - event := structs.NewTaskEvent(structs.TaskRestartSignal).SetRestartReason(reasonStr) + event := newTaskRestartEvent(reasonStr, failure) select { case r.restartCh <- event: diff --git a/client/task_runner_test.go b/client/task_runner_test.go index 6894115e337..f532e77df45 100644 --- a/client/task_runner_test.go +++ b/client/task_runner_test.go @@ -21,6 +21,7 @@ import ( "github.com/hashicorp/nomad/client/driver/env" cstructs "github.com/hashicorp/nomad/client/structs" "github.com/hashicorp/nomad/client/vaultclient" + "github.com/hashicorp/nomad/command/agent/consul" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/testutil" @@ -56,10 +57,21 @@ func (m *MockTaskStateUpdater) Update(name, state string, event *structs.TaskEve } } +// String for debugging purposes. +func (m *MockTaskStateUpdater) String() string { + s := fmt.Sprintf("Updates:\n state=%q\n failed=%t\n events=\n", m.state, m.failed) + for _, e := range m.events { + s += fmt.Sprintf(" %#v\n", e) + } + return s +} + type taskRunnerTestCtx struct { upd *MockTaskStateUpdater tr *TaskRunner allocDir *allocdir.AllocDir + vault *vaultclient.MockVaultClient + consul *mockConsulServiceClient } // Cleanup calls Destroy on the task runner and alloc dir @@ -130,7 +142,13 @@ func testTaskRunnerFromAlloc(t *testing.T, restarts bool, alloc *structs.Allocat if !restarts { tr.restartTracker = noRestartsTracker() } - return &taskRunnerTestCtx{upd, tr, allocDir} + return &taskRunnerTestCtx{ + upd: upd, + tr: tr, + allocDir: allocDir, + vault: vclient, + consul: cclient, + } } // testWaitForTaskToStart waits for the task to or fails the test @@ -657,7 +675,7 @@ func TestTaskRunner_RestartTask(t *testing.T) { // Wait for it to start go func() { testWaitForTaskToStart(t, ctx) - ctx.tr.Restart("test", "restart") + ctx.tr.Restart("test", "restart", false) // Wait for it to restart then kill go func() { @@ -1251,8 +1269,7 @@ func TestTaskRunner_Template_NewVaultToken(t *testing.T) { }) // Error the token renewal - vc := ctx.tr.vaultClient.(*vaultclient.MockVaultClient) - renewalCh, ok := vc.RenewTokens[token] + renewalCh, ok := ctx.vault.RenewTokens[token] if !ok { t.Fatalf("no renewal channel") } @@ -1279,13 +1296,12 @@ func TestTaskRunner_Template_NewVaultToken(t *testing.T) { }) // Check the token was revoked - m := ctx.tr.vaultClient.(*vaultclient.MockVaultClient) testutil.WaitForResult(func() (bool, error) { - if len(m.StoppedTokens) != 1 { - return false, fmt.Errorf("Expected a stopped token: %v", m.StoppedTokens) + if len(ctx.vault.StoppedTokens) != 1 { + return false, fmt.Errorf("Expected a stopped token: %v", ctx.vault.StoppedTokens) } - if a := m.StoppedTokens[0]; a != token { + if a := ctx.vault.StoppedTokens[0]; a != token { return false, fmt.Errorf("got stopped token %q; want %q", a, token) } return true, nil @@ -1317,8 +1333,7 @@ func TestTaskRunner_VaultManager_Restart(t *testing.T) { testWaitForTaskToStart(t, ctx) // Error the token renewal - vc := ctx.tr.vaultClient.(*vaultclient.MockVaultClient) - renewalCh, ok := vc.RenewTokens[ctx.tr.vaultFuture.Get()] + renewalCh, ok := ctx.vault.RenewTokens[ctx.tr.vaultFuture.Get()] if !ok { t.Fatalf("no renewal channel") } @@ -1394,8 +1409,7 @@ func TestTaskRunner_VaultManager_Signal(t *testing.T) { testWaitForTaskToStart(t, ctx) // Error the token renewal - vc := ctx.tr.vaultClient.(*vaultclient.MockVaultClient) - renewalCh, ok := vc.RenewTokens[ctx.tr.vaultFuture.Get()] + renewalCh, ok := ctx.vault.RenewTokens[ctx.tr.vaultFuture.Get()] if !ok { t.Fatalf("no renewal channel") } @@ -1726,20 +1740,19 @@ func TestTaskRunner_ShutdownDelay(t *testing.T) { // Service should get removed quickly; loop until RemoveTask is called found := false - mockConsul := ctx.tr.consul.(*mockConsulServiceClient) deadline := destroyed.Add(task.ShutdownDelay) for time.Now().Before(deadline) { time.Sleep(5 * time.Millisecond) - mockConsul.mu.Lock() - n := len(mockConsul.ops) + ctx.consul.mu.Lock() + n := len(ctx.consul.ops) if n < 2 { - mockConsul.mu.Unlock() + ctx.consul.mu.Unlock() continue } - lastOp := mockConsul.ops[n-1].op - mockConsul.mu.Unlock() + lastOp := ctx.consul.ops[n-1].op + ctx.consul.mu.Unlock() if lastOp == "remove" { found = true @@ -1762,3 +1775,97 @@ func TestTaskRunner_ShutdownDelay(t *testing.T) { t.Fatalf("task exited before shutdown delay") } } + +// TestTaskRunner_CheckWatcher_Restart asserts that when enabled an unhealthy +// Consul check will cause a task to restart following restart policy rules. +func TestTaskRunner_CheckWatcher_Restart(t *testing.T) { + t.Parallel() + + alloc := mock.Alloc() + + // Make the restart policy fail within this test + tg := alloc.Job.TaskGroups[0] + tg.RestartPolicy.Attempts = 2 + tg.RestartPolicy.Interval = 1 * time.Minute + tg.RestartPolicy.Delay = 10 * time.Millisecond + tg.RestartPolicy.Mode = structs.RestartPolicyModeFail + + task := tg.Tasks[0] + task.Driver = "mock_driver" + task.Config = map[string]interface{}{ + "exit_code": "0", + "run_for": "100s", + } + + // Make the task register a check that fails + task.Services[0].Checks[0] = &structs.ServiceCheck{ + Name: "test-restarts", + Type: structs.ServiceCheckTCP, + Interval: 50 * time.Millisecond, + CheckRestart: &structs.CheckRestart{ + Limit: 2, + Grace: 100 * time.Millisecond, + }, + } + + ctx := testTaskRunnerFromAlloc(t, true, alloc) + + // Replace mock Consul ServiceClient, with the real ServiceClient + // backed by a mock consul whose checks are always unhealthy. + consulAgent := consul.NewMockAgent() + consulAgent.SetStatus("critical") + consulClient := consul.NewServiceClient(consulAgent, true, ctx.tr.logger) + go consulClient.Run() + defer consulClient.Shutdown() + + ctx.tr.consul = consulClient + ctx.consul = nil // prevent accidental use of old mock + + ctx.tr.MarkReceived() + go ctx.tr.Run() + defer ctx.Cleanup() + + select { + case <-ctx.tr.WaitCh(): + case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second): + t.Fatalf("timeout") + } + + expected := []string{ + "Received", + "Task Setup", + "Started", + "Restart Signaled", + "Killing", + "Killed", + "Restarting", + "Started", + "Restart Signaled", + "Killing", + "Killed", + "Restarting", + "Started", + "Restart Signaled", + "Killing", + "Killed", + "Not Restarting", + } + + if n := len(ctx.upd.events); n != len(expected) { + t.Fatalf("should have %d ctx.updates found %d: %s", len(expected), n, ctx.upd) + } + + if ctx.upd.state != structs.TaskStateDead { + t.Fatalf("TaskState %v; want %v", ctx.upd.state, structs.TaskStateDead) + } + + if !ctx.upd.failed { + t.Fatalf("expected failed") + } + + for i, actual := range ctx.upd.events { + if actual.Type != expected[i] { + t.Errorf("%.2d - Expected %q but found %q", i, expected[i], actual.Type) + } + } +} diff --git a/client/task_runner_unix_test.go b/client/task_runner_unix_test.go index bed7c956d79..b7c2aa4412b 100644 --- a/client/task_runner_unix_test.go +++ b/client/task_runner_unix_test.go @@ -53,7 +53,7 @@ func TestTaskRunner_RestartSignalTask_NotRunning(t *testing.T) { } // Send a restart - ctx.tr.Restart("test", "don't panic") + ctx.tr.Restart("test", "don't panic", false) if len(ctx.upd.events) != 2 { t.Fatalf("should have 2 ctx.updates: %#v", ctx.upd.events) diff --git a/command/agent/consul/catalog_testing.go b/command/agent/consul/catalog_testing.go index f0dd0326ce0..6b28940f114 100644 --- a/command/agent/consul/catalog_testing.go +++ b/command/agent/consul/catalog_testing.go @@ -1,7 +1,9 @@ package consul import ( + "fmt" "log" + "sync" "github.com/hashicorp/consul/api" ) @@ -25,3 +27,119 @@ func (m *MockCatalog) Service(service, tag string, q *api.QueryOptions) ([]*api. m.logger.Printf("[DEBUG] mock_consul: Service(%q, %q, %#v) -> (nil, nil, nil)", service, tag, q) return nil, nil, nil } + +// MockAgent is a fake in-memory Consul backend for ServiceClient. +type MockAgent struct { + // maps of what services and checks have been registered + services map[string]*api.AgentServiceRegistration + checks map[string]*api.AgentCheckRegistration + mu sync.Mutex + + // when UpdateTTL is called the check ID will have its counter inc'd + checkTTLs map[string]int + + // What check status to return from Checks() + checkStatus string +} + +// NewMockAgent that returns all checks as passing. +func NewMockAgent() *MockAgent { + return &MockAgent{ + services: make(map[string]*api.AgentServiceRegistration), + checks: make(map[string]*api.AgentCheckRegistration), + checkTTLs: make(map[string]int), + checkStatus: api.HealthPassing, + } +} + +// SetStatus that Checks() should return. Returns old status value. +func (c *MockAgent) SetStatus(s string) string { + c.mu.Lock() + old := c.checkStatus + c.checkStatus = s + c.mu.Unlock() + return old +} + +func (c *MockAgent) Services() (map[string]*api.AgentService, error) { + c.mu.Lock() + defer c.mu.Unlock() + + r := make(map[string]*api.AgentService, len(c.services)) + for k, v := range c.services { + r[k] = &api.AgentService{ + ID: v.ID, + Service: v.Name, + Tags: make([]string, len(v.Tags)), + Port: v.Port, + Address: v.Address, + EnableTagOverride: v.EnableTagOverride, + } + copy(r[k].Tags, v.Tags) + } + return r, nil +} + +func (c *MockAgent) Checks() (map[string]*api.AgentCheck, error) { + c.mu.Lock() + defer c.mu.Unlock() + + r := make(map[string]*api.AgentCheck, len(c.checks)) + for k, v := range c.checks { + r[k] = &api.AgentCheck{ + CheckID: v.ID, + Name: v.Name, + Status: c.checkStatus, + Notes: v.Notes, + ServiceID: v.ServiceID, + ServiceName: c.services[v.ServiceID].Name, + } + } + return r, nil +} + +func (c *MockAgent) CheckRegister(check *api.AgentCheckRegistration) error { + c.mu.Lock() + defer c.mu.Unlock() + c.checks[check.ID] = check + + // Be nice and make checks reachable-by-service + scheck := check.AgentServiceCheck + c.services[check.ServiceID].Checks = append(c.services[check.ServiceID].Checks, &scheck) + return nil +} + +func (c *MockAgent) CheckDeregister(checkID string) error { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.checks, checkID) + delete(c.checkTTLs, checkID) + return nil +} + +func (c *MockAgent) ServiceRegister(service *api.AgentServiceRegistration) error { + c.mu.Lock() + defer c.mu.Unlock() + c.services[service.ID] = service + return nil +} + +func (c *MockAgent) ServiceDeregister(serviceID string) error { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.services, serviceID) + return nil +} + +func (c *MockAgent) UpdateTTL(id string, output string, status string) error { + c.mu.Lock() + defer c.mu.Unlock() + check, ok := c.checks[id] + if !ok { + return fmt.Errorf("unknown check id: %q", id) + } + // Flip initial status to passing + check.Status = "passing" + c.checkTTLs[id]++ + return nil +} diff --git a/command/agent/consul/check_watcher.go b/command/agent/consul/check_watcher.go new file mode 100644 index 00000000000..4b0656765d6 --- /dev/null +++ b/command/agent/consul/check_watcher.go @@ -0,0 +1,317 @@ +package consul + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/hashicorp/consul/api" + "github.com/hashicorp/nomad/nomad/structs" +) + +const ( + // defaultPollFreq is the default rate to poll the Consul Checks API + defaultPollFreq = 900 * time.Millisecond +) + +// ChecksAPI is the part of the Consul API the checkWatcher requires. +type ChecksAPI interface { + // Checks returns a list of all checks. + Checks() (map[string]*api.AgentCheck, error) +} + +// TaskRestarter allows the checkWatcher to restart tasks. +type TaskRestarter interface { + Restart(source, reason string, failure bool) +} + +// checkRestart handles restarting a task if a check is unhealthy. +type checkRestart struct { + allocID string + taskName string + checkID string + checkName string + taskKey string // composite of allocID + taskName for uniqueness + + task TaskRestarter + grace time.Duration + interval time.Duration + timeLimit time.Duration + ignoreWarnings bool + + // Mutable fields + + // unhealthyState is the time a check first went unhealthy. Set to the + // zero value if the check passes before timeLimit. + unhealthyState time.Time + + // graceUntil is when the check's grace period expires and unhealthy + // checks should be counted. + graceUntil time.Time + + logger *log.Logger +} + +// apply restart state for check and restart task if necessary. Currrent +// timestamp is passed in so all check updates have the same view of time (and +// to ease testing). +// +// Returns true if a restart was triggered in which case this check should be +// removed (checks are added on task startup). +func (c *checkRestart) apply(now time.Time, status string) bool { + healthy := func() { + if !c.unhealthyState.IsZero() { + c.logger.Printf("[DEBUG] consul.health: alloc %q task %q check %q became healthy; canceling restart", + c.allocID, c.taskName, c.checkName) + c.unhealthyState = time.Time{} + } + } + switch status { + case api.HealthCritical: + case api.HealthWarning: + if c.ignoreWarnings { + // Warnings are ignored, reset state and exit + healthy() + return false + } + default: + // All other statuses are ok, reset state and exit + healthy() + return false + } + + if now.Before(c.graceUntil) { + // In grace period, exit + return false + } + + if c.unhealthyState.IsZero() { + // First failure, set restart deadline + if c.timeLimit != 0 { + c.logger.Printf("[DEBUG] consul.health: alloc %q task %q check %q became unhealthy. Restarting in %s if not healthy", + c.allocID, c.taskName, c.checkName, c.timeLimit) + } + c.unhealthyState = now + } + + // restart timeLimit after start of this check becoming unhealthy + restartAt := c.unhealthyState.Add(c.timeLimit) + + // Must test >= because if limit=1, restartAt == first failure + if now.Equal(restartAt) || now.After(restartAt) { + // hasn't become healthy by deadline, restart! + c.logger.Printf("[DEBUG] consul.health: restarting alloc %q task %q due to unhealthy check %q", c.allocID, c.taskName, c.checkName) + + // Tell TaskRunner to restart due to failure + const failure = true + c.task.Restart("healthcheck", fmt.Sprintf("check %q unhealthy", c.checkName), failure) + return true + } + + return false +} + +// checkWatchUpdates add or remove checks from the watcher +type checkWatchUpdate struct { + checkID string + remove bool + checkRestart *checkRestart +} + +// checkWatcher watches Consul checks and restarts tasks when they're +// unhealthy. +type checkWatcher struct { + consul ChecksAPI + + // pollFreq is how often to poll the checks API and defaults to + // defaultPollFreq + pollFreq time.Duration + + // checkUpdateCh is how watches (and removals) are sent to the main + // watching loop + checkUpdateCh chan checkWatchUpdate + + // done is closed when Run has exited + done chan struct{} + + // lastErr is true if the last Consul call failed. It is used to + // squelch repeated error messages. + lastErr bool + + logger *log.Logger +} + +// newCheckWatcher creates a new checkWatcher but does not call its Run method. +func newCheckWatcher(logger *log.Logger, consul ChecksAPI) *checkWatcher { + return &checkWatcher{ + consul: consul, + pollFreq: defaultPollFreq, + checkUpdateCh: make(chan checkWatchUpdate, 8), + done: make(chan struct{}), + logger: logger, + } +} + +// Run the main Consul checks watching loop to restart tasks when their checks +// fail. Blocks until context is canceled. +func (w *checkWatcher) Run(ctx context.Context) { + defer close(w.done) + + // map of check IDs to their metadata + checks := map[string]*checkRestart{} + + // timer for check polling + checkTimer := time.NewTimer(0) + defer checkTimer.Stop() // ensure timer is never leaked + + stopTimer := func() { + checkTimer.Stop() + select { + case <-checkTimer.C: + default: + } + } + + // disable by default + stopTimer() + + // Main watch loop + for { + // disable polling if there are no checks + if len(checks) == 0 { + stopTimer() + } + + select { + case update := <-w.checkUpdateCh: + if update.remove { + // Remove a check + delete(checks, update.checkID) + continue + } + + // Add/update a check + checks[update.checkID] = update.checkRestart + w.logger.Printf("[DEBUG] consul.health: watching alloc %q task %q check %q", + update.checkRestart.allocID, update.checkRestart.taskName, update.checkRestart.checkName) + + // if first check was added make sure polling is enabled + if len(checks) == 1 { + stopTimer() + checkTimer.Reset(w.pollFreq) + } + + case <-ctx.Done(): + return + + case <-checkTimer.C: + checkTimer.Reset(w.pollFreq) + + // Set "now" as the point in time the following check results represent + now := time.Now() + + results, err := w.consul.Checks() + if err != nil { + if !w.lastErr { + w.lastErr = true + w.logger.Printf("[ERR] consul.health: error retrieving health checks: %q", err) + } + continue + } + + w.lastErr = false + + // Keep track of tasks restarted this period so they + // are only restarted once and all of their checks are + // removed. + restartedTasks := map[string]struct{}{} + + // Loop over watched checks and update their status from results + for cid, check := range checks { + if _, ok := restartedTasks[check.taskKey]; ok { + // Check for this task already restarted; remove and skip check + delete(checks, cid) + continue + } + + result, ok := results[cid] + if !ok { + // Only warn if outside grace period to avoid races with check registration + if now.After(check.graceUntil) { + w.logger.Printf("[WARN] consul.health: watched check %q (%s) not found in Consul", check.checkName, cid) + } + continue + } + + restarted := check.apply(now, result.Status) + if restarted { + // Checks are registered+watched on + // startup, so it's safe to remove them + // whenever they're restarted + delete(checks, cid) + + restartedTasks[check.taskKey] = struct{}{} + } + } + + // Ensure even passing checks for restartedTasks are removed + if len(restartedTasks) > 0 { + for cid, check := range checks { + if _, ok := restartedTasks[check.taskKey]; ok { + delete(checks, cid) + } + } + } + } + } +} + +// Watch a check and restart its task if unhealthy. +func (w *checkWatcher) Watch(allocID, taskName, checkID string, check *structs.ServiceCheck, restarter TaskRestarter) { + if !check.TriggersRestarts() { + // Not watched, noop + return + } + + c := &checkRestart{ + allocID: allocID, + taskName: taskName, + checkID: checkID, + checkName: check.Name, + taskKey: fmt.Sprintf("%s%s", allocID, taskName), // unique task ID + task: restarter, + interval: check.Interval, + grace: check.CheckRestart.Grace, + graceUntil: time.Now().Add(check.CheckRestart.Grace), + timeLimit: check.Interval * time.Duration(check.CheckRestart.Limit-1), + ignoreWarnings: check.CheckRestart.IgnoreWarnings, + logger: w.logger, + } + + update := checkWatchUpdate{ + checkID: checkID, + checkRestart: c, + } + + select { + case w.checkUpdateCh <- update: + // sent watch + case <-w.done: + // exited; nothing to do + } +} + +// Unwatch a check. +func (w *checkWatcher) Unwatch(cid string) { + c := checkWatchUpdate{ + checkID: cid, + remove: true, + } + select { + case w.checkUpdateCh <- c: + // sent remove watch + case <-w.done: + // exited; nothing to do + } +} diff --git a/command/agent/consul/check_watcher_test.go b/command/agent/consul/check_watcher_test.go new file mode 100644 index 00000000000..ccc06b5e6fa --- /dev/null +++ b/command/agent/consul/check_watcher_test.go @@ -0,0 +1,320 @@ +package consul + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/hashicorp/consul/api" + "github.com/hashicorp/nomad/nomad/structs" +) + +// checkRestartRecord is used by a testFakeCtx to record when restarts occur +// due to a watched check. +type checkRestartRecord struct { + timestamp time.Time + source string + reason string + failure bool +} + +// fakeCheckRestarter is a test implementation of TaskRestarter. +type fakeCheckRestarter struct { + // restarts is a slice of all of the restarts triggered by the checkWatcher + restarts []checkRestartRecord + + // need the checkWatcher to re-Watch restarted tasks like TaskRunner + watcher *checkWatcher + + // check to re-Watch on restarts + check *structs.ServiceCheck + allocID string + taskName string + checkName string +} + +// newFakeCheckRestart creates a new TaskRestarter. It needs all of the +// parameters checkWatcher.Watch expects. +func newFakeCheckRestarter(w *checkWatcher, allocID, taskName, checkName string, c *structs.ServiceCheck) *fakeCheckRestarter { + return &fakeCheckRestarter{ + watcher: w, + check: c, + allocID: allocID, + taskName: taskName, + checkName: checkName, + } +} + +// Restart implements part of the TaskRestarter interface needed for check +// watching and is normally fulfilled by a TaskRunner. +// +// Restarts are recorded in the []restarts field and re-Watch the check. +func (c *fakeCheckRestarter) Restart(source, reason string, failure bool) { + c.restarts = append(c.restarts, checkRestartRecord{time.Now(), source, reason, failure}) + + // Re-Watch the check just like TaskRunner + c.watcher.Watch(c.allocID, c.taskName, c.checkName, c.check, c) +} + +// String for debugging +func (c *fakeCheckRestarter) String() string { + s := fmt.Sprintf("%s %s %s restarts:\n", c.allocID, c.taskName, c.checkName) + for _, r := range c.restarts { + s += fmt.Sprintf("%s - %s: %s (failure: %t)\n", r.timestamp, r.source, r.reason, r.failure) + } + return s +} + +// checkResponse is a response returned by the fakeChecksAPI after the given +// time. +type checkResponse struct { + at time.Time + id string + status string +} + +// fakeChecksAPI implements the Checks() method for testing Consul. +type fakeChecksAPI struct { + // responses is a map of check ids to their status at a particular + // time. checkResponses must be in chronological order. + responses map[string][]checkResponse +} + +func newFakeChecksAPI() *fakeChecksAPI { + return &fakeChecksAPI{responses: make(map[string][]checkResponse)} +} + +// add a new check status to Consul at the given time. +func (c *fakeChecksAPI) add(id, status string, at time.Time) { + c.responses[id] = append(c.responses[id], checkResponse{at, id, status}) +} + +func (c *fakeChecksAPI) Checks() (map[string]*api.AgentCheck, error) { + now := time.Now() + result := make(map[string]*api.AgentCheck, len(c.responses)) + + // Use the latest response for each check + for k, vs := range c.responses { + for _, v := range vs { + if v.at.After(now) { + break + } + result[k] = &api.AgentCheck{ + CheckID: k, + Name: k, + Status: v.status, + } + } + } + + return result, nil +} + +// testWatcherSetup sets up a fakeChecksAPI and a real checkWatcher with a test +// logger and faster poll frequency. +func testWatcherSetup() (*fakeChecksAPI, *checkWatcher) { + fakeAPI := newFakeChecksAPI() + cw := newCheckWatcher(testLogger(), fakeAPI) + cw.pollFreq = 10 * time.Millisecond + return fakeAPI, cw +} + +func testCheck() *structs.ServiceCheck { + return &structs.ServiceCheck{ + Name: "testcheck", + Interval: 100 * time.Millisecond, + Timeout: 100 * time.Millisecond, + CheckRestart: &structs.CheckRestart{ + Limit: 3, + Grace: 100 * time.Millisecond, + IgnoreWarnings: false, + }, + } +} + +// TestCheckWatcher_Skip asserts unwatched checks are ignored. +func TestCheckWatcher_Skip(t *testing.T) { + t.Parallel() + + // Create a check with restarting disabled + check := testCheck() + check.CheckRestart = nil + + cw := newCheckWatcher(testLogger(), newFakeChecksAPI()) + restarter1 := newFakeCheckRestarter(cw, "testalloc1", "testtask1", "testcheck1", check) + cw.Watch("testalloc1", "testtask1", "testcheck1", check, restarter1) + + // Check should have been dropped as it's not watched + if n := len(cw.checkUpdateCh); n != 0 { + t.Fatalf("expected 0 checks to be enqueued for watching but found %d", n) + } +} + +// TestCheckWatcher_Healthy asserts healthy tasks are not restarted. +func TestCheckWatcher_Healthy(t *testing.T) { + t.Parallel() + + fakeAPI, cw := testWatcherSetup() + + check1 := testCheck() + restarter1 := newFakeCheckRestarter(cw, "testalloc1", "testtask1", "testcheck1", check1) + cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1) + + check2 := testCheck() + check2.CheckRestart.Limit = 1 + check2.CheckRestart.Grace = 0 + restarter2 := newFakeCheckRestarter(cw, "testalloc2", "testtask2", "testcheck2", check2) + cw.Watch("testalloc2", "testtask2", "testcheck2", check2, restarter2) + + // Make both checks healthy from the beginning + fakeAPI.add("testcheck1", "passing", time.Time{}) + fakeAPI.add("testcheck2", "passing", time.Time{}) + + // Run + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + cw.Run(ctx) + + // Ensure restart was never called + if n := len(restarter1.restarts); n > 0 { + t.Errorf("expected check 1 to not be restarted but found %d:\n%s", n, restarter1) + } + if n := len(restarter2.restarts); n > 0 { + t.Errorf("expected check 2 to not be restarted but found %d:\n%s", n, restarter2) + } +} + +// TestCheckWatcher_HealthyWarning asserts checks in warning with +// ignore_warnings=true do not restart tasks. +func TestCheckWatcher_HealthyWarning(t *testing.T) { + t.Parallel() + + fakeAPI, cw := testWatcherSetup() + + check1 := testCheck() + check1.CheckRestart.Limit = 1 + check1.CheckRestart.Grace = 0 + check1.CheckRestart.IgnoreWarnings = true + restarter1 := newFakeCheckRestarter(cw, "testalloc1", "testtask1", "testcheck1", check1) + cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1) + + // Check is always in warning but that's ok + fakeAPI.add("testcheck1", "warning", time.Time{}) + + // Run + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + cw.Run(ctx) + + // Ensure restart was never called on check 1 + if n := len(restarter1.restarts); n > 0 { + t.Errorf("expected check 1 to not be restarted but found %d", n) + } +} + +// TestCheckWatcher_Flapping asserts checks that flap from healthy to unhealthy +// before the unhealthy limit is reached do not restart tasks. +func TestCheckWatcher_Flapping(t *testing.T) { + t.Parallel() + + fakeAPI, cw := testWatcherSetup() + + check1 := testCheck() + check1.CheckRestart.Grace = 0 + restarter1 := newFakeCheckRestarter(cw, "testalloc1", "testtask1", "testcheck1", check1) + cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1) + + // Check flaps and is never failing for the full 200ms needed to restart + now := time.Now() + fakeAPI.add("testcheck1", "passing", now) + fakeAPI.add("testcheck1", "critical", now.Add(100*time.Millisecond)) + fakeAPI.add("testcheck1", "passing", now.Add(250*time.Millisecond)) + fakeAPI.add("testcheck1", "critical", now.Add(300*time.Millisecond)) + fakeAPI.add("testcheck1", "passing", now.Add(450*time.Millisecond)) + + ctx, cancel := context.WithTimeout(context.Background(), 600*time.Millisecond) + defer cancel() + cw.Run(ctx) + + // Ensure restart was never called on check 1 + if n := len(restarter1.restarts); n > 0 { + t.Errorf("expected check 1 to not be restarted but found %d\n%s", n, restarter1) + } +} + +// TestCheckWatcher_Unwatch asserts unwatching checks prevents restarts. +func TestCheckWatcher_Unwatch(t *testing.T) { + t.Parallel() + + fakeAPI, cw := testWatcherSetup() + + // Unwatch immediately + check1 := testCheck() + check1.CheckRestart.Limit = 1 + check1.CheckRestart.Grace = 100 * time.Millisecond + restarter1 := newFakeCheckRestarter(cw, "testalloc1", "testtask1", "testcheck1", check1) + cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1) + cw.Unwatch("testcheck1") + + // Always failing + fakeAPI.add("testcheck1", "critical", time.Time{}) + + ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond) + defer cancel() + cw.Run(ctx) + + // Ensure restart was never called on check 1 + if n := len(restarter1.restarts); n > 0 { + t.Errorf("expected check 1 to not be restarted but found %d\n%s", n, restarter1) + } +} + +// TestCheckWatcher_MultipleChecks asserts that when there are multiple checks +// for a single task, all checks should be removed when any of them restart the +// task to avoid multiple restarts. +func TestCheckWatcher_MultipleChecks(t *testing.T) { + t.Parallel() + + fakeAPI, cw := testWatcherSetup() + + check1 := testCheck() + check1.CheckRestart.Limit = 1 + restarter1 := newFakeCheckRestarter(cw, "testalloc1", "testtask1", "testcheck1", check1) + cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1) + + check2 := testCheck() + check2.CheckRestart.Limit = 1 + restarter2 := newFakeCheckRestarter(cw, "testalloc1", "testtask1", "testcheck2", check2) + cw.Watch("testalloc1", "testtask1", "testcheck2", check2, restarter2) + + check3 := testCheck() + check3.CheckRestart.Limit = 1 + restarter3 := newFakeCheckRestarter(cw, "testalloc1", "testtask1", "testcheck3", check3) + cw.Watch("testalloc1", "testtask1", "testcheck3", check3, restarter3) + + // check 2 & 3 fail long enough to cause 1 restart, but only 1 should restart + now := time.Now() + fakeAPI.add("testcheck1", "critical", now) + fakeAPI.add("testcheck1", "passing", now.Add(150*time.Millisecond)) + fakeAPI.add("testcheck2", "critical", now) + fakeAPI.add("testcheck2", "passing", now.Add(150*time.Millisecond)) + fakeAPI.add("testcheck3", "passing", time.Time{}) + + // Run + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + cw.Run(ctx) + + // Ensure that restart was only called once on check 1 or 2. Since + // checks are in a map it's random which check triggers the restart + // first. + if n := len(restarter1.restarts) + len(restarter2.restarts); n != 1 { + t.Errorf("expected check 1 & 2 to be restarted 1 time but found %d\ncheck 1:\n%s\ncheck 2:%s", + n, restarter1, restarter2) + } + + if n := len(restarter3.restarts); n != 0 { + t.Errorf("expected check 3 to not be restarted but found %d:\n%s", n, restarter3) + } +} diff --git a/command/agent/consul/client.go b/command/agent/consul/client.go index 8285785fbde..5116a366512 100644 --- a/command/agent/consul/client.go +++ b/command/agent/consul/client.go @@ -1,6 +1,7 @@ package consul import ( + "context" "fmt" "log" "net" @@ -223,6 +224,9 @@ type ServiceClient struct { // seen is 1 if Consul has ever been seen; otherise 0. Accessed with // atomics. seen int32 + + // checkWatcher restarts checks that are unhealthy. + checkWatcher *checkWatcher } // NewServiceClient creates a new Consul ServiceClient from an existing Consul API @@ -245,6 +249,7 @@ func NewServiceClient(consulClient AgentAPI, skipVerifySupport bool, logger *log allocRegistrations: make(map[string]*AllocRegistration), agentServices: make(map[string]struct{}), agentChecks: make(map[string]struct{}), + checkWatcher: newCheckWatcher(logger, consulClient), } } @@ -267,6 +272,12 @@ func (c *ServiceClient) hasSeen() bool { // be called exactly once. func (c *ServiceClient) Run() { defer close(c.exitCh) + + // start checkWatcher + ctx, cancelWatcher := context.WithCancel(context.Background()) + defer cancelWatcher() + go c.checkWatcher.Run(ctx) + retryTimer := time.NewTimer(0) <-retryTimer.C // disabled by default failures := 0 @@ -274,6 +285,7 @@ func (c *ServiceClient) Run() { select { case <-retryTimer.C: case <-c.shutdownCh: + cancelWatcher() case ops := <-c.opCh: c.merge(ops) } @@ -656,7 +668,7 @@ func (c *ServiceClient) checkRegs(ops *operations, allocID, serviceID string, se // Checks will always use the IP from the Task struct (host's IP). // // Actual communication with Consul is done asynchrously (see Run). -func (c *ServiceClient) RegisterTask(allocID string, task *structs.Task, exec driver.ScriptExecutor, net *cstructs.DriverNetwork) error { +func (c *ServiceClient) RegisterTask(allocID string, task *structs.Task, restarter TaskRestarter, exec driver.ScriptExecutor, net *cstructs.DriverNetwork) error { // Fast path numServices := len(task.Services) if numServices == 0 { @@ -679,6 +691,18 @@ func (c *ServiceClient) RegisterTask(allocID string, task *structs.Task, exec dr c.addTaskRegistration(allocID, task.Name, t) c.commit(ops) + + // Start watching checks. Done after service registrations are built + // since an error building them could leak watches. + for _, service := range task.Services { + serviceID := makeTaskServiceID(allocID, task.Name, service) + for _, check := range service.Checks { + if check.TriggersRestarts() { + checkID := makeCheckID(serviceID, check) + c.checkWatcher.Watch(allocID, task.Name, checkID, check, restarter) + } + } + } return nil } @@ -686,7 +710,7 @@ func (c *ServiceClient) RegisterTask(allocID string, task *structs.Task, exec dr // changed. // // DriverNetwork must not change between invocations for the same allocation. -func (c *ServiceClient) UpdateTask(allocID string, existing, newTask *structs.Task, exec driver.ScriptExecutor, net *cstructs.DriverNetwork) error { +func (c *ServiceClient) UpdateTask(allocID string, existing, newTask *structs.Task, restarter TaskRestarter, exec driver.ScriptExecutor, net *cstructs.DriverNetwork) error { ops := &operations{} t := new(TaskRegistration) @@ -709,7 +733,13 @@ func (c *ServiceClient) UpdateTask(allocID string, existing, newTask *structs.Ta // Existing service entry removed ops.deregServices = append(ops.deregServices, existingID) for _, check := range existingSvc.Checks { - ops.deregChecks = append(ops.deregChecks, makeCheckID(existingID, check)) + cid := makeCheckID(existingID, check) + ops.deregChecks = append(ops.deregChecks, cid) + + // Unwatch watched checks + if check.TriggersRestarts() { + c.checkWatcher.Unwatch(cid) + } } continue } @@ -730,9 +760,9 @@ func (c *ServiceClient) UpdateTask(allocID string, existing, newTask *structs.Ta } // Check to see what checks were updated - existingChecks := make(map[string]struct{}, len(existingSvc.Checks)) + existingChecks := make(map[string]*structs.ServiceCheck, len(existingSvc.Checks)) for _, check := range existingSvc.Checks { - existingChecks[makeCheckID(existingID, check)] = struct{}{} + existingChecks[makeCheckID(existingID, check)] = check } // Register new checks @@ -748,15 +778,28 @@ func (c *ServiceClient) UpdateTask(allocID string, existing, newTask *structs.Ta if err != nil { return err } + for _, checkID := range newCheckIDs { sreg.checkIDs[checkID] = struct{}{} + } + + } + + // Update all watched checks as CheckRestart fields aren't part of ID + if check.TriggersRestarts() { + c.checkWatcher.Watch(allocID, newTask.Name, checkID, check, restarter) } } // Remove existing checks not in updated service - for cid := range existingChecks { + for cid, check := range existingChecks { ops.deregChecks = append(ops.deregChecks, cid) + + // Unwatch checks + if check.TriggersRestarts() { + c.checkWatcher.Unwatch(cid) + } } } @@ -774,6 +817,18 @@ func (c *ServiceClient) UpdateTask(allocID string, existing, newTask *structs.Ta c.addTaskRegistration(allocID, newTask.Name, t) c.commit(ops) + + // Start watching checks. Done after service registrations are built + // since an error building them could leak watches. + for _, service := range newIDs { + serviceID := makeTaskServiceID(allocID, newTask.Name, service) + for _, check := range service.Checks { + if check.TriggersRestarts() { + checkID := makeCheckID(serviceID, check) + c.checkWatcher.Watch(allocID, newTask.Name, checkID, check, restarter) + } + } + } return nil } @@ -788,7 +843,12 @@ func (c *ServiceClient) RemoveTask(allocID string, task *structs.Task) { ops.deregServices = append(ops.deregServices, id) for _, check := range service.Checks { - ops.deregChecks = append(ops.deregChecks, makeCheckID(id, check)) + cid := makeCheckID(id, check) + ops.deregChecks = append(ops.deregChecks, cid) + + if check.TriggersRestarts() { + c.checkWatcher.Unwatch(cid) + } } } diff --git a/command/agent/consul/unit_test.go b/command/agent/consul/unit_test.go index 2a83d2989e9..8bd3e08a5d8 100644 --- a/command/agent/consul/unit_test.go +++ b/command/agent/consul/unit_test.go @@ -8,7 +8,7 @@ import ( "os" "reflect" "strings" - "sync" + "sync/atomic" "testing" "time" @@ -54,12 +54,23 @@ func testTask() *structs.Task { } } +// restartRecorder is a minimal TaskRestarter implementation that simply +// counts how many restarts were triggered. +type restartRecorder struct { + restarts int64 +} + +func (r *restartRecorder) Restart(source, reason string, failure bool) { + atomic.AddInt64(&r.restarts, 1) +} + // testFakeCtx contains a fake Consul AgentAPI and implements the Exec // interface to allow testing without running Consul. type testFakeCtx struct { ServiceClient *ServiceClient - FakeConsul *fakeConsul + FakeConsul *MockAgent Task *structs.Task + Restarter *restartRecorder // Ticked whenever a script is called execs chan int @@ -99,126 +110,21 @@ func (t *testFakeCtx) syncOnce() error { // setupFake creates a testFakeCtx with a ServiceClient backed by a fakeConsul. // A test Task is also provided. func setupFake() *testFakeCtx { - fc := newFakeConsul() + fc := NewMockAgent() return &testFakeCtx{ ServiceClient: NewServiceClient(fc, true, testLogger()), FakeConsul: fc, Task: testTask(), + Restarter: &restartRecorder{}, execs: make(chan int, 100), } } -// fakeConsul is a fake in-memory Consul backend for ServiceClient. -type fakeConsul struct { - // maps of what services and checks have been registered - services map[string]*api.AgentServiceRegistration - checks map[string]*api.AgentCheckRegistration - mu sync.Mutex - - // when UpdateTTL is called the check ID will have its counter inc'd - checkTTLs map[string]int - - // What check status to return from Checks() - checkStatus string -} - -func newFakeConsul() *fakeConsul { - return &fakeConsul{ - services: make(map[string]*api.AgentServiceRegistration), - checks: make(map[string]*api.AgentCheckRegistration), - checkTTLs: make(map[string]int), - checkStatus: api.HealthPassing, - } -} - -func (c *fakeConsul) Services() (map[string]*api.AgentService, error) { - c.mu.Lock() - defer c.mu.Unlock() - - r := make(map[string]*api.AgentService, len(c.services)) - for k, v := range c.services { - r[k] = &api.AgentService{ - ID: v.ID, - Service: v.Name, - Tags: make([]string, len(v.Tags)), - Port: v.Port, - Address: v.Address, - EnableTagOverride: v.EnableTagOverride, - } - copy(r[k].Tags, v.Tags) - } - return r, nil -} - -func (c *fakeConsul) Checks() (map[string]*api.AgentCheck, error) { - c.mu.Lock() - defer c.mu.Unlock() - - r := make(map[string]*api.AgentCheck, len(c.checks)) - for k, v := range c.checks { - r[k] = &api.AgentCheck{ - CheckID: v.ID, - Name: v.Name, - Status: c.checkStatus, - Notes: v.Notes, - ServiceID: v.ServiceID, - ServiceName: c.services[v.ServiceID].Name, - } - } - return r, nil -} - -func (c *fakeConsul) CheckRegister(check *api.AgentCheckRegistration) error { - c.mu.Lock() - defer c.mu.Unlock() - c.checks[check.ID] = check - - // Be nice and make checks reachable-by-service - scheck := check.AgentServiceCheck - c.services[check.ServiceID].Checks = append(c.services[check.ServiceID].Checks, &scheck) - return nil -} - -func (c *fakeConsul) CheckDeregister(checkID string) error { - c.mu.Lock() - defer c.mu.Unlock() - delete(c.checks, checkID) - delete(c.checkTTLs, checkID) - return nil -} - -func (c *fakeConsul) ServiceRegister(service *api.AgentServiceRegistration) error { - c.mu.Lock() - defer c.mu.Unlock() - c.services[service.ID] = service - return nil -} - -func (c *fakeConsul) ServiceDeregister(serviceID string) error { - c.mu.Lock() - defer c.mu.Unlock() - delete(c.services, serviceID) - return nil -} - -func (c *fakeConsul) UpdateTTL(id string, output string, status string) error { - c.mu.Lock() - defer c.mu.Unlock() - check, ok := c.checks[id] - if !ok { - return fmt.Errorf("unknown check id: %q", id) - } - // Flip initial status to passing - check.Status = "passing" - c.checkTTLs[id]++ - return nil -} - func TestConsul_ChangeTags(t *testing.T) { ctx := setupFake() allocID := "allocid" - if err := ctx.ServiceClient.RegisterTask(allocID, ctx.Task, nil, nil); err != nil { + if err := ctx.ServiceClient.RegisterTask(allocID, ctx.Task, ctx.Restarter, nil, nil); err != nil { t.Fatalf("unexpected error registering task: %v", err) } @@ -260,7 +166,7 @@ func TestConsul_ChangeTags(t *testing.T) { origTask := ctx.Task ctx.Task = testTask() ctx.Task.Services[0].Tags[0] = "newtag" - if err := ctx.ServiceClient.UpdateTask("allocid", origTask, ctx.Task, nil, nil); err != nil { + if err := ctx.ServiceClient.UpdateTask("allocid", origTask, ctx.Task, nil, nil, nil); err != nil { t.Fatalf("unexpected error registering task: %v", err) } if err := ctx.syncOnce(); err != nil { @@ -342,7 +248,7 @@ func TestConsul_ChangePorts(t *testing.T) { }, } - if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx, nil); err != nil { + if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx.Restarter, ctx, nil); err != nil { t.Fatalf("unexpected error registering task: %v", err) } @@ -430,7 +336,7 @@ func TestConsul_ChangePorts(t *testing.T) { // Removed PortLabel; should default to service's (y) }, } - if err := ctx.ServiceClient.UpdateTask("allocid", origTask, ctx.Task, ctx, nil); err != nil { + if err := ctx.ServiceClient.UpdateTask("allocid", origTask, ctx.Task, nil, ctx, nil); err != nil { t.Fatalf("unexpected error registering task: %v", err) } if err := ctx.syncOnce(); err != nil { @@ -505,11 +411,14 @@ func TestConsul_ChangeChecks(t *testing.T) { Interval: time.Second, Timeout: time.Second, PortLabel: "x", + CheckRestart: &structs.CheckRestart{ + Limit: 3, + }, }, } allocID := "allocid" - if err := ctx.ServiceClient.RegisterTask(allocID, ctx.Task, ctx, nil); err != nil { + if err := ctx.ServiceClient.RegisterTask(allocID, ctx.Task, ctx.Restarter, ctx, nil); err != nil { t.Fatalf("unexpected error registering task: %v", err) } @@ -521,6 +430,13 @@ func TestConsul_ChangeChecks(t *testing.T) { t.Fatalf("expected 1 service but found %d:\n%#v", n, ctx.FakeConsul.services) } + // Assert a check restart watch update was enqueued and clear it + if n := len(ctx.ServiceClient.checkWatcher.checkUpdateCh); n != 1 { + t.Fatalf("expected 1 check restart update but found %d", n) + } + upd := <-ctx.ServiceClient.checkWatcher.checkUpdateCh + c1ID := upd.checkID + // Query the allocs registrations and then again when we update. The IDs // should change reg1, err := ctx.ServiceClient.AllocRegistrations(allocID) @@ -566,6 +482,9 @@ func TestConsul_ChangeChecks(t *testing.T) { Interval: 2 * time.Second, Timeout: time.Second, PortLabel: "x", + CheckRestart: &structs.CheckRestart{ + Limit: 3, + }, }, { Name: "c2", @@ -576,9 +495,29 @@ func TestConsul_ChangeChecks(t *testing.T) { PortLabel: "x", }, } - if err := ctx.ServiceClient.UpdateTask("allocid", origTask, ctx.Task, ctx, nil); err != nil { + if err := ctx.ServiceClient.UpdateTask("allocid", origTask, ctx.Task, nil, ctx, nil); err != nil { t.Fatalf("unexpected error registering task: %v", err) } + + // Assert 2 check restart watch updates was enqueued + if n := len(ctx.ServiceClient.checkWatcher.checkUpdateCh); n != 2 { + t.Fatalf("expected 2 check restart updates but found %d", n) + } + + // First the new watch + upd = <-ctx.ServiceClient.checkWatcher.checkUpdateCh + if upd.checkID == c1ID || upd.remove { + t.Fatalf("expected check watch update to be an add of checkID=%q but found remove=%t checkID=%q", + c1ID, upd.remove, upd.checkID) + } + + // Then remove the old watch + upd = <-ctx.ServiceClient.checkWatcher.checkUpdateCh + if upd.checkID != c1ID || !upd.remove { + t.Fatalf("expected check watch update to be a removal of checkID=%q but found remove=%t checkID=%q", + c1ID, upd.remove, upd.checkID) + } + if err := ctx.syncOnce(); err != nil { t.Fatalf("unexpected error syncing task: %v", err) } @@ -601,6 +540,9 @@ func TestConsul_ChangeChecks(t *testing.T) { if expected := fmt.Sprintf(":%d", xPort); v.TCP != expected { t.Errorf("expected Port x=%v but found: %v", expected, v.TCP) } + + // update id + c1ID = k case "c2": if expected := fmt.Sprintf("http://:%d/", xPort); v.HTTP != expected { t.Errorf("expected Port x=%v but found: %v", expected, v.HTTP) @@ -644,13 +586,73 @@ func TestConsul_ChangeChecks(t *testing.T) { } } } + + // Alter a CheckRestart and make sure the watcher is updated but nothing else + origTask = ctx.Task.Copy() + ctx.Task.Services[0].Checks = []*structs.ServiceCheck{ + { + Name: "c1", + Type: "tcp", + Interval: 2 * time.Second, + Timeout: time.Second, + PortLabel: "x", + CheckRestart: &structs.CheckRestart{ + Limit: 11, + }, + }, + { + Name: "c2", + Type: "http", + Path: "/", + Interval: time.Second, + Timeout: time.Second, + PortLabel: "x", + }, + } + if err := ctx.ServiceClient.UpdateTask("allocid", origTask, ctx.Task, nil, ctx, nil); err != nil { + t.Fatalf("unexpected error registering task: %v", err) + } + if err := ctx.syncOnce(); err != nil { + t.Fatalf("unexpected error syncing task: %v", err) + } + + if n := len(ctx.FakeConsul.checks); n != 2 { + t.Fatalf("expected 2 check but found %d:\n%#v", n, ctx.FakeConsul.checks) + } + + for k, v := range ctx.FakeConsul.checks { + if v.Name == "c1" { + if k != c1ID { + t.Errorf("expected c1 to still have id %q but found %q", c1ID, k) + } + break + } + } + + // Assert a check restart watch update was enqueued for a removal and an add + if n := len(ctx.ServiceClient.checkWatcher.checkUpdateCh); n != 1 { + t.Fatalf("expected 1 check restart update but found %d", n) + } + <-ctx.ServiceClient.checkWatcher.checkUpdateCh } // TestConsul_RegServices tests basic service registration. func TestConsul_RegServices(t *testing.T) { ctx := setupFake() - if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, nil, nil); err != nil { + // Add a check w/restarting + ctx.Task.Services[0].Checks = []*structs.ServiceCheck{ + { + Name: "testcheck", + Type: "tcp", + Interval: 100 * time.Millisecond, + CheckRestart: &structs.CheckRestart{ + Limit: 3, + }, + }, + } + + if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx.Restarter, nil, nil); err != nil { t.Fatalf("unexpected error registering task: %v", err) } @@ -661,6 +663,7 @@ func TestConsul_RegServices(t *testing.T) { if n := len(ctx.FakeConsul.services); n != 1 { t.Fatalf("expected 1 service but found %d:\n%#v", n, ctx.FakeConsul.services) } + for _, v := range ctx.FakeConsul.services { if v.Name != ctx.Task.Services[0].Name { t.Errorf("expected Name=%q != %q", ctx.Task.Services[0].Name, v.Name) @@ -673,13 +676,38 @@ func TestConsul_RegServices(t *testing.T) { } } + // Assert the check update is pending + if n := len(ctx.ServiceClient.checkWatcher.checkUpdateCh); n != 1 { + t.Fatalf("expected 1 check restart update but found %d", n) + } + + // Assert the check update is properly formed + checkUpd := <-ctx.ServiceClient.checkWatcher.checkUpdateCh + if checkUpd.checkRestart.allocID != "allocid" { + t.Fatalf("expected check's allocid to be %q but found %q", "allocid", checkUpd.checkRestart.allocID) + } + if expected := 200 * time.Millisecond; checkUpd.checkRestart.timeLimit != expected { + t.Fatalf("expected check's time limit to be %v but found %v", expected, checkUpd.checkRestart.timeLimit) + } + // Make a change which will register a new service ctx.Task.Services[0].Name = "taskname-service2" ctx.Task.Services[0].Tags[0] = "tag3" - if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, nil, nil); err != nil { + if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx.Restarter, nil, nil); err != nil { t.Fatalf("unpexpected error registering task: %v", err) } + // Assert check update is pending + if n := len(ctx.ServiceClient.checkWatcher.checkUpdateCh); n != 1 { + t.Fatalf("expected 1 check restart update but found %d", n) + } + + // Assert the check update's id has changed + checkUpd2 := <-ctx.ServiceClient.checkWatcher.checkUpdateCh + if checkUpd.checkID == checkUpd2.checkID { + t.Fatalf("expected new check update to have a new ID both both have: %q", checkUpd.checkID) + } + // Make sure changes don't take affect until sync() is called (since // Run() isn't running) if n := len(ctx.FakeConsul.services); n != 1 { @@ -727,6 +755,20 @@ func TestConsul_RegServices(t *testing.T) { t.Errorf("expected original task to survive not %q", v.Name) } } + + // Assert check update is pending + if n := len(ctx.ServiceClient.checkWatcher.checkUpdateCh); n != 1 { + t.Fatalf("expected 1 check restart update but found %d", n) + } + + // Assert the check update's id is correct and that it's a removal + checkUpd3 := <-ctx.ServiceClient.checkWatcher.checkUpdateCh + if checkUpd2.checkID != checkUpd3.checkID { + t.Fatalf("expected checkid %q but found %q", checkUpd2.checkID, checkUpd3.checkID) + } + if !checkUpd3.remove { + t.Fatalf("expected check watch removal update but found: %#v", checkUpd3) + } } // TestConsul_ShutdownOK tests the ok path for the shutdown logic in @@ -750,7 +792,7 @@ func TestConsul_ShutdownOK(t *testing.T) { go ctx.ServiceClient.Run() // Register a task and agent - if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx, nil); err != nil { + if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx.Restarter, ctx, nil); err != nil { t.Fatalf("unexpected error registering task: %v", err) } @@ -823,7 +865,7 @@ func TestConsul_ShutdownSlow(t *testing.T) { go ctx.ServiceClient.Run() // Register a task and agent - if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx, nil); err != nil { + if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx.Restarter, ctx, nil); err != nil { t.Fatalf("unexpected error registering task: %v", err) } @@ -894,7 +936,7 @@ func TestConsul_ShutdownBlocked(t *testing.T) { go ctx.ServiceClient.Run() // Register a task and agent - if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx, nil); err != nil { + if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx.Restarter, ctx, nil); err != nil { t.Fatalf("unexpected error registering task: %v", err) } @@ -951,7 +993,7 @@ func TestConsul_NoTLSSkipVerifySupport(t *testing.T) { }, } - if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, nil, nil); err != nil { + if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx.Restarter, nil, nil); err != nil { t.Fatalf("unexpected error registering task: %v", err) } @@ -991,7 +1033,7 @@ func TestConsul_CancelScript(t *testing.T) { }, } - if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx, nil); err != nil { + if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx.Restarter, ctx, nil); err != nil { t.Fatalf("unexpected error registering task: %v", err) } @@ -1028,7 +1070,7 @@ func TestConsul_CancelScript(t *testing.T) { }, } - if err := ctx.ServiceClient.UpdateTask("allocid", origTask, ctx.Task, ctx, nil); err != nil { + if err := ctx.ServiceClient.UpdateTask("allocid", origTask, ctx.Task, ctx.Restarter, ctx, nil); err != nil { t.Fatalf("unexpected error registering task: %v", err) } @@ -1115,7 +1157,7 @@ func TestConsul_DriverNetwork_AutoUse(t *testing.T) { AutoAdvertise: true, } - if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx, net); err != nil { + if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx.Restarter, ctx, net); err != nil { t.Fatalf("unexpected error registering task: %v", err) } @@ -1218,7 +1260,7 @@ func TestConsul_DriverNetwork_NoAutoUse(t *testing.T) { AutoAdvertise: false, } - if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx, net); err != nil { + if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx.Restarter, ctx, net); err != nil { t.Fatalf("unexpected error registering task: %v", err) } @@ -1304,7 +1346,7 @@ func TestConsul_DriverNetwork_Change(t *testing.T) { } // Initial service should advertise host port x - if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx, net); err != nil { + if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx.Restarter, ctx, net); err != nil { t.Fatalf("unexpected error registering task: %v", err) } @@ -1314,7 +1356,7 @@ func TestConsul_DriverNetwork_Change(t *testing.T) { orig := ctx.Task.Copy() ctx.Task.Services[0].AddressMode = structs.AddressModeHost - if err := ctx.ServiceClient.UpdateTask("allocid", orig, ctx.Task, ctx, net); err != nil { + if err := ctx.ServiceClient.UpdateTask("allocid", orig, ctx.Task, ctx.Restarter, ctx, net); err != nil { t.Fatalf("unexpected error updating task: %v", err) } @@ -1324,7 +1366,7 @@ func TestConsul_DriverNetwork_Change(t *testing.T) { orig = ctx.Task.Copy() ctx.Task.Services[0].AddressMode = structs.AddressModeDriver - if err := ctx.ServiceClient.UpdateTask("allocid", orig, ctx.Task, ctx, net); err != nil { + if err := ctx.ServiceClient.UpdateTask("allocid", orig, ctx.Task, ctx.Restarter, ctx, net); err != nil { t.Fatalf("unexpected error updating task: %v", err) } diff --git a/command/agent/job_endpoint.go b/command/agent/job_endpoint.go index 75a77f0fcf7..5fcf6516165 100644 --- a/command/agent/job_endpoint.go +++ b/command/agent/job_endpoint.go @@ -704,6 +704,13 @@ func ApiTaskToStructsTask(apiTask *api.Task, structsTask *structs.Task) { Header: check.Header, Method: check.Method, } + if check.CheckRestart != nil { + structsTask.Services[i].Checks[j].CheckRestart = &structs.CheckRestart{ + Limit: check.CheckRestart.Limit, + Grace: *check.CheckRestart.Grace, + IgnoreWarnings: check.CheckRestart.IgnoreWarnings, + } + } } } } diff --git a/command/agent/job_endpoint_test.go b/command/agent/job_endpoint_test.go index f7ad1d398d7..9cf62f5e6d8 100644 --- a/command/agent/job_endpoint_test.go +++ b/command/agent/job_endpoint_test.go @@ -1216,6 +1216,11 @@ func TestJobs_ApiJobToStructsJob(t *testing.T) { Interval: 4 * time.Second, Timeout: 2 * time.Second, InitialStatus: "ok", + CheckRestart: &api.CheckRestart{ + Limit: 3, + Grace: helper.TimeToPtr(10 * time.Second), + IgnoreWarnings: true, + }, }, }, }, @@ -1406,6 +1411,11 @@ func TestJobs_ApiJobToStructsJob(t *testing.T) { Interval: 4 * time.Second, Timeout: 2 * time.Second, InitialStatus: "ok", + CheckRestart: &structs.CheckRestart{ + Limit: 3, + Grace: 10 * time.Second, + IgnoreWarnings: true, + }, }, }, }, diff --git a/jobspec/parse.go b/jobspec/parse.go index efda8f95f4e..61e5b9968b3 100644 --- a/jobspec/parse.go +++ b/jobspec/parse.go @@ -922,6 +922,7 @@ func parseServices(jobName string, taskGroupName string, task *api.Task, service } delete(m, "check") + delete(m, "check_restart") if err := mapstructure.WeakDecode(m, &service); err != nil { return err @@ -941,6 +942,18 @@ func parseServices(jobName string, taskGroupName string, task *api.Task, service } } + // Filter check_restart + if cro := checkList.Filter("check_restart"); len(cro.Items) > 0 { + if len(cro.Items) > 1 { + return fmt.Errorf("check_restart '%s': cannot have more than 1 check_restart", service.Name) + } + if cr, err := parseCheckRestart(cro.Items[0]); err != nil { + return multierror.Prefix(err, fmt.Sprintf("service: '%s',", service.Name)) + } else { + service.CheckRestart = cr + } + } + task.Services[idx] = &service } @@ -965,6 +978,7 @@ func parseChecks(service *api.Service, checkObjs *ast.ObjectList) error { "tls_skip_verify", "header", "method", + "check_restart", } if err := checkHCLKeys(co.Val, valid); err != nil { return multierror.Prefix(err, "check ->") @@ -1006,6 +1020,8 @@ func parseChecks(service *api.Service, checkObjs *ast.ObjectList) error { delete(cm, "header") } + delete(cm, "check_restart") + dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ DecodeHook: mapstructure.StringToTimeDurationHookFunc(), WeaklyTypedInput: true, @@ -1018,12 +1034,63 @@ func parseChecks(service *api.Service, checkObjs *ast.ObjectList) error { return err } + // Filter check_restart + var checkRestartList *ast.ObjectList + if ot, ok := co.Val.(*ast.ObjectType); ok { + checkRestartList = ot.List + } else { + return fmt.Errorf("check_restart '%s': should be an object", check.Name) + } + + if cro := checkRestartList.Filter("check_restart"); len(cro.Items) > 0 { + if len(cro.Items) > 1 { + return fmt.Errorf("check_restart '%s': cannot have more than 1 check_restart", check.Name) + } + if cr, err := parseCheckRestart(cro.Items[0]); err != nil { + return multierror.Prefix(err, fmt.Sprintf("check: '%s',", check.Name)) + } else { + check.CheckRestart = cr + } + } + service.Checks[idx] = check } return nil } +func parseCheckRestart(cro *ast.ObjectItem) (*api.CheckRestart, error) { + valid := []string{ + "limit", + "grace_period", + "ignore_warnings", + } + + if err := checkHCLKeys(cro.Val, valid); err != nil { + return nil, multierror.Prefix(err, "check_restart ->") + } + + var checkRestart api.CheckRestart + var crm map[string]interface{} + if err := hcl.DecodeObject(&crm, cro.Val); err != nil { + return nil, err + } + + dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + DecodeHook: mapstructure.StringToTimeDurationHookFunc(), + WeaklyTypedInput: true, + Result: &checkRestart, + }) + if err != nil { + return nil, err + } + if err := dec.Decode(crm); err != nil { + return nil, err + } + + return &checkRestart, nil +} + func parseResources(result *api.Resources, list *ast.ObjectList) error { list = list.Elem() if len(list.Items) == 0 { diff --git a/jobspec/parse_test.go b/jobspec/parse_test.go index bf3d2082329..c0cb4d902b9 100644 --- a/jobspec/parse_test.go +++ b/jobspec/parse_test.go @@ -130,6 +130,11 @@ func TestParse(t *testing.T) { PortLabel: "admin", Interval: 10 * time.Second, Timeout: 2 * time.Second, + CheckRestart: &api.CheckRestart{ + Limit: 3, + Grace: helper.TimeToPtr(10 * time.Second), + IgnoreWarnings: true, + }, }, }, }, diff --git a/jobspec/test-fixtures/basic.hcl b/jobspec/test-fixtures/basic.hcl index 843bd9b016a..a9dab9e7262 100644 --- a/jobspec/test-fixtures/basic.hcl +++ b/jobspec/test-fixtures/basic.hcl @@ -95,6 +95,12 @@ job "binstore-storagelocker" { interval = "10s" timeout = "2s" port = "admin" + + check_restart { + limit = 3 + grace_period = "10s" + ignore_warnings = true + } } } diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go index 8175024a32e..c4a66dd9be0 100644 --- a/nomad/structs/structs.go +++ b/nomad/structs/structs.go @@ -2757,6 +2757,41 @@ func (tg *TaskGroup) GoString() string { return fmt.Sprintf("*%#v", *tg) } +// CheckRestart describes if and when a task should be restarted based on +// failing health checks. +type CheckRestart struct { + Limit int // Restart task after this many unhealthy intervals + Grace time.Duration // Grace time to give tasks after starting to get healthy + IgnoreWarnings bool // If true treat checks in `warning` as passing +} + +func (c *CheckRestart) Copy() *CheckRestart { + if c == nil { + return nil + } + + nc := new(CheckRestart) + *nc = *c + return nc +} + +func (c *CheckRestart) Validate() error { + if c == nil { + return nil + } + + var mErr multierror.Error + if c.Limit < 0 { + mErr.Errors = append(mErr.Errors, fmt.Errorf("limit must be greater than or equal to 0 but found %d", c.Limit)) + } + + if c.Grace < 0 { + mErr.Errors = append(mErr.Errors, fmt.Errorf("grace period must be greater than or equal to 0 but found %d", c.Grace)) + } + + return mErr.ErrorOrNil() +} + const ( ServiceCheckHTTP = "http" ServiceCheckTCP = "tcp" @@ -2788,6 +2823,7 @@ type ServiceCheck struct { TLSSkipVerify bool // Skip TLS verification when Protocol=https Method string // HTTP Method to use (GET by default) Header map[string][]string // HTTP Headers for Consul to set when making HTTP checks + CheckRestart *CheckRestart // If and when a task should be restarted based on checks } func (sc *ServiceCheck) Copy() *ServiceCheck { @@ -2798,6 +2834,7 @@ func (sc *ServiceCheck) Copy() *ServiceCheck { *nsc = *sc nsc.Args = helper.CopySliceString(sc.Args) nsc.Header = helper.CopyMapStringSliceString(sc.Header) + nsc.CheckRestart = sc.CheckRestart.Copy() return nsc } @@ -2863,7 +2900,7 @@ func (sc *ServiceCheck) validate() error { } - return nil + return sc.CheckRestart.Validate() } // RequiresPort returns whether the service check requires the task has a port. @@ -2876,6 +2913,12 @@ func (sc *ServiceCheck) RequiresPort() bool { } } +// TriggersRestarts returns true if this check should be watched and trigger a restart +// on failure. +func (sc *ServiceCheck) TriggersRestarts() bool { + return sc.CheckRestart != nil && sc.CheckRestart.Limit > 0 +} + // Hash all ServiceCheck fields and the check's corresponding service ID to // create an identifier. The identifier is not guaranteed to be unique as if // the PortLabel is blank, the Service's PortLabel will be used after Hash is @@ -3013,6 +3056,7 @@ func (s *Service) Validate() error { mErr.Errors = append(mErr.Errors, fmt.Errorf("check %s invalid: %v", c.Name, err)) } } + return mErr.ErrorOrNil() } @@ -3753,7 +3797,7 @@ type TaskEvent struct { } func (te *TaskEvent) GoString() string { - return fmt.Sprintf("%v at %v", te.Type, te.Time) + return fmt.Sprintf("%v - %v", te.Time, te.Type) } // SetMessage sets the message of TaskEvent diff --git a/nomad/structs/structs_test.go b/nomad/structs/structs_test.go index 1ae8b782947..d4858cc7880 100644 --- a/nomad/structs/structs_test.go +++ b/nomad/structs/structs_test.go @@ -1154,6 +1154,24 @@ func TestTask_Validate_Service_Check(t *testing.T) { } } +func TestTask_Validate_Service_Check_CheckRestart(t *testing.T) { + invalidCheckRestart := &CheckRestart{ + Limit: -1, + Grace: -1, + } + + err := invalidCheckRestart.Validate() + assert.NotNil(t, err, "invalidateCheckRestart.Validate()") + assert.Len(t, err.(*multierror.Error).Errors, 2) + + validCheckRestart := &CheckRestart{} + assert.Nil(t, validCheckRestart.Validate()) + + validCheckRestart.Limit = 1 + validCheckRestart.Grace = 1 + assert.Nil(t, validCheckRestart.Validate()) +} + func TestTask_Validate_LogConfig(t *testing.T) { task := &Task{ LogConfig: DefaultLogConfig(), diff --git a/website/source/api/json-jobs.html.md b/website/source/api/json-jobs.html.md index 8851df1d5cd..8e52d3ea43e 100644 --- a/website/source/api/json-jobs.html.md +++ b/website/source/api/json-jobs.html.md @@ -66,7 +66,12 @@ Below is the JSON representation of the job outputed by `$ nomad init`: "Interval": 10000000000, "Timeout": 2000000000, "InitialStatus": "", - "TLSSkipVerify": false + "TLSSkipVerify": false, + "CheckRestart": { + "Limit": 3, + "Grace": "30s", + "IgnoreWarnings": false + } }] }], "Resources": { @@ -377,6 +382,20 @@ The `Task` object supports the following keys: - `TLSSkipVerify`: If true, Consul will not attempt to verify the certificate when performing HTTPS checks. Requires Consul >= 0.7.2. + - `CheckRestart`: `CheckRestart` is an object which enables + restarting of tasks based upon Consul health checks. + + - `Limit`: The number of unhealthy checks allowed before the + service is restarted. Defaults to `0` which disables + health-based restarts. + + - `Grace`: The duration to wait after a task starts or restarts + before counting unhealthy checks count against the limit. + Defaults to "1s". + + - `IgnoreWarnings`: Treat checks that are warning as passing. + Defaults to false which means warnings are considered unhealthy. + - `ShutdownDelay` - Specifies the duration to wait when killing a task between removing it from Consul and sending it a shutdown signal. Ideally services would fail healthchecks once they receive a shutdown signal. Alternatively diff --git a/website/source/docs/job-specification/check_restart.html.md b/website/source/docs/job-specification/check_restart.html.md new file mode 100644 index 00000000000..bac3cec6197 --- /dev/null +++ b/website/source/docs/job-specification/check_restart.html.md @@ -0,0 +1,152 @@ +--- +layout: "docs" +page_title: "check_restart Stanza - Job Specification" +sidebar_current: "docs-job-specification-check_restart" +description: |- + The "check_restart" stanza instructs Nomad when to restart tasks with + unhealthy service checks. +--- + +# `check_restart` Stanza + + + + + + + + + + +
Placement + job -> group -> task -> service -> **check_restart** +
Placement + job -> group -> task -> service -> check -> **check_restart** +
+ +As of Nomad 0.7 the `check_restart` stanza instructs Nomad when to restart +tasks with unhealthy service checks. When a health check in Consul has been +unhealthy for the `limit` specified in a `check_restart` stanza, it is +restarted according to the task group's [`restart` policy][restart_stanza]. The +`check_restart` settings apply to [`check`s][check_stanza], but may also be +placed on [`service`s][service_stanza] to apply to all checks on a service. +If `check_restart` is set on both the check and service, the stanza's are +merged with the check values taking precedence. + +```hcl +job "mysql" { + group "mysqld" { + + restart { + attempts = 3 + delay = "10s" + interval = "10m" + mode = "fail" + } + + task "server" { + service { + tags = ["leader", "mysql"] + + port = "db" + + check { + type = "tcp" + port = "db" + interval = "10s" + timeout = "2s" + } + + check { + type = "script" + name = "check_table" + command = "/usr/local/bin/check_mysql_table_status" + args = ["--verbose"] + interval = "60s" + timeout = "5s" + + check_restart { + limit = 3 + grace = "90s" + ignore_warnings = false + } + } + } + } + } +} +``` + +- `limit` `(int: 0)` - Restart task when a health check has failed `limit` + times. For example 1 causes a restart on the first failure. The default, + `0`, disables health check based restarts. Failures must be consecutive. A + single passing check will reset the count, so flapping services may not be + restarted. + +- `grace` `(string: "1s")` - Duration to wait after a task starts or restarts + before checking its health. + +- `ignore_warnings` `(bool: false)` - By default checks with both `critical` + and `warning` statuses are considered unhealthy. Setting `ignore_warnings = + true` treats a `warning` status like `passing` and will not trigger a restart. + +## Example Behavior + +Using the example `mysql` above would have the following behavior: + +```hcl +check_restart { + # ... + grace = "90s" + # ... +} +``` + +When the `server` task first starts and is registered in Consul, its health +will not be checked for 90 seconds. This gives the server time to startup. + +```hcl +check_restart { + limit = 3 + # ... +} +``` + +After the grace period if the script check fails, it has 180 seconds (`60s +interval * 3 limit`) to pass before a restart is triggered. Once a restart is +triggered the task group's [`restart` policy][restart_stanza] takes control: + +```hcl +restart { + # ... + delay = "10s" + # ... +} +``` + +The [`restart` stanza][restart_stanza] controls the restart behavior of the +task. In this case it will stop the task and then wait 10 seconds before +starting it again. + +Once the task restarts Nomad waits the `grace` period again before starting to +check the task's health. + + +```hcl +restart { + attempts = 3 + # ... + interval = "10m" + mode = "fail" +} +``` + +If the check continues to fail, the task will be restarted up to `attempts` +times within an `interval`. If the `restart` attempts are reached within the +`limit` then the `mode` controls the behavior. In this case the task would fail +and not be restarted again. See the [`restart` stanza][restart_stanza] for +details. + +[check_stanza]: /docs/job-specification/service.html#check-parameters "check stanza" +[restart_stanza]: /docs/job-specification/restart.html "restart stanza" +[service_stanza]: /docs/job-specification/service.html "service stanza" diff --git a/website/source/docs/job-specification/service.html.md b/website/source/docs/job-specification/service.html.md index cef747b0dd6..5d4fc677585 100644 --- a/website/source/docs/job-specification/service.html.md +++ b/website/source/docs/job-specification/service.html.md @@ -47,6 +47,12 @@ job "docs" { args = ["--verbose"] interval = "60s" timeout = "5s" + + check_restart { + limit = 3 + grace = "90s" + ignore_warnings = false + } } } } @@ -111,6 +117,8 @@ scripts. - `args` `(array: [])` - Specifies additional arguments to the `command`. This only applies to script-based health checks. +- `check_restart` - See [`check_restart` stanza][check_restart_stanza]. + - `command` `(string: )` - Specifies the command to run for performing the health check. The script must exit: 0 for passing, 1 for warning, or any other value for a failing health check. This is required for script-based @@ -170,6 +178,7 @@ the header to be set multiple times, once for each value. ```hcl service { + # ... check { type = "http" port = "lb" @@ -315,7 +324,9 @@ service { [qemu driver][qemu] since the Nomad client does not have access to the file system of a task for that driver. +[check_restart_stanza]: /docs/job-specification/check_restart.html "check_restart stanza" [service-discovery]: /docs/service-discovery/index.html "Nomad Service Discovery" [interpolation]: /docs/runtime/interpolation.html "Nomad Runtime Interpolation" [network]: /docs/job-specification/network.html "Nomad network Job Specification" [qemu]: /docs/drivers/qemu.html "Nomad qemu Driver" +[restart_stanza]: /docs/job-specification/restart.html "restart stanza" diff --git a/website/source/layouts/docs.erb b/website/source/layouts/docs.erb index ded57717ca7..39b98d36619 100644 --- a/website/source/layouts/docs.erb +++ b/website/source/layouts/docs.erb @@ -26,6 +26,9 @@ > artifact + > + check_restart + > constraint