Skip to content

Commit

Permalink
Merge pull request #2480 from hashicorp/b-restart-panic
Browse files Browse the repository at this point in the history
Fix panic when restarting non-running task
  • Loading branch information
dadgar authored Mar 24, 2017
2 parents e2968ee + 0021961 commit 68b9735
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 31 deletions.
51 changes: 23 additions & 28 deletions client/task_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -958,14 +958,34 @@ func (r *TaskRunner) run() {
}

case se := <-r.signalCh:
r.logger.Printf("[DEBUG] client: task being signalled with %v: %s", se.s, se.e.TaskSignalReason)
r.runningLock.Lock()
running := r.running
r.runningLock.Unlock()
common := fmt.Sprintf("signal %v to task %v for alloc %q", se.s, r.task.Name, r.alloc.ID)
if !running {
// Send no error
r.logger.Printf("[DEBUG] client: skipping %s", common)
se.result <- nil
continue
}

r.logger.Printf("[DEBUG] client: sending %s", common)
r.setState(structs.TaskStateRunning, se.e)

res := r.handle.Signal(se.s)
se.result <- res

case event := <-r.restartCh:
r.logger.Printf("[DEBUG] client: task being restarted: %s", event.RestartReason)
r.runningLock.Lock()
running := r.running
r.runningLock.Unlock()
common := fmt.Sprintf("task %v for alloc %q", r.task.Name, r.alloc.ID)
if !running {
r.logger.Printf("[DEBUG] client: skipping restart of %v: task isn't running", common)
continue
}

r.logger.Printf("[DEBUG] client: restarting %s: %v", common, event.RestartReason)
r.setState(structs.TaskStateRunning, event)
r.killTask(nil)

Expand Down Expand Up @@ -1365,23 +1385,9 @@ func (r *TaskRunner) handleDestroy() (destroyed bool, err error) {

// Restart will restart the task
func (r *TaskRunner) Restart(source, reason string) {

reasonStr := fmt.Sprintf("%s: %s", source, reason)
event := structs.NewTaskEvent(structs.TaskRestartSignal).SetRestartReason(reasonStr)

r.logger.Printf("[DEBUG] client: restarting task %v for alloc %q: %v",
r.task.Name, r.alloc.ID, reasonStr)

r.runningLock.Lock()
running := r.running
r.runningLock.Unlock()

// Drop the restart event
if !running {
r.logger.Printf("[DEBUG] client: skipping restart since task isn't running")
return
}

select {
case r.restartCh <- event:
case <-r.waitCh:
Expand All @@ -1394,24 +1400,13 @@ func (r *TaskRunner) Signal(source, reason string, s os.Signal) error {
reasonStr := fmt.Sprintf("%s: %s", source, reason)
event := structs.NewTaskEvent(structs.TaskSignaling).SetTaskSignal(s).SetTaskSignalReason(reasonStr)

r.logger.Printf("[DEBUG] client: sending signal %v to task %v for alloc %q", s, r.task.Name, r.alloc.ID)

r.runningLock.Lock()
running := r.running
r.runningLock.Unlock()

// Drop the restart event
if !running {
r.logger.Printf("[DEBUG] client: skipping signal since task isn't running")
return nil
}

resCh := make(chan error)
se := SignalEvent{
s: s,
e: event,
result: resCh,
}

select {
case r.signalCh <- se:
case <-r.waitCh:
Expand Down
66 changes: 63 additions & 3 deletions client/task_runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,10 +316,10 @@ func TestTaskRunner_Update(t *testing.T) {
return false, fmt.Errorf("Task not copied")
}
if ctx.tr.restartTracker.policy.Mode != newMode {
return false, fmt.Errorf("restart policy not ctx.upd.ted")
return false, fmt.Errorf("restart policy not ctx.updated")
}
if ctx.tr.handle.ID() == oldHandle {
return false, fmt.Errorf("handle not ctx.upd.ted")
return false, fmt.Errorf("handle not ctx.updated")
}
return true, nil
}, func(err error) {
Expand Down Expand Up @@ -645,6 +645,66 @@ func TestTaskRunner_RestartTask(t *testing.T) {
}
}

// TestTaskRunner_RestartSignalTask_NotRunning checks that calling Signal and
// Restart on a task runner whose task has not started yet returns cleanly and
// leaves the recorded task state and events untouched.
func TestTaskRunner_RestartSignalTask_NotRunning(t *testing.T) {
	alloc := mock.Alloc()
	task := alloc.Job.TaskGroups[0].Tasks[0]
	task.Driver = "mock_driver"
	task.Config = map[string]interface{}{
		"exit_code": "0",
		"run_for":   "100s",
	}

	// A Vault stanza is attached so that the start of the task can be held
	// back by stalling token derivation below.
	task.Vault = &structs.Vault{Policies: []string{"default"}}

	ctx := testTaskRunnerFromAlloc(t, true, alloc)
	ctx.tr.MarkReceived()
	defer ctx.Cleanup()

	// Token derivation blocks until unblockCh is closed, which only happens
	// when this test function returns.
	const vaultToken = "1234"
	unblockCh := make(chan struct{})
	defer close(unblockCh)
	mockVault := ctx.tr.vaultClient.(*vaultclient.MockVaultClient)
	mockVault.DeriveTokenFn = func(*structs.Allocation, []string) (map[string]string, error) {
		<-unblockCh
		return map[string]string{task.Name: vaultToken}, nil
	}
	go ctx.tr.Run()

	// Wait one second; the runner exiting during that window is a failure.
	select {
	case <-ctx.tr.WaitCh():
		t.Fatalf("premature exit")
	case <-time.After(1 * time.Second):
	}

	// Signal first; an error here fails the test.
	err := ctx.tr.Signal("test", "don't panic", syscall.SIGCHLD)
	if err != nil {
		t.Fatalf("Signalling errored: %v", err)
	}

	// Then request a restart.
	ctx.tr.Restart("test", "don't panic")

	// Exactly two events are expected, with the task still pending.
	if n := len(ctx.upd.events); n != 2 {
		t.Fatalf("should have 2 ctx.updates: %#v", ctx.upd.events)
	}

	if got, want := ctx.upd.state, structs.TaskStatePending; got != want {
		t.Fatalf("TaskState %v; want %v", got, want)
	}

	// The two events must be Received followed by Setup.
	if got, want := ctx.upd.events[0].Type, structs.TaskReceived; got != want {
		t.Fatalf("First Event was %v; want %v", got, want)
	}

	if got, want := ctx.upd.events[1].Type, structs.TaskSetup; got != want {
		t.Fatalf("Second Event was %v; want %v", got, want)
	}
}

func TestTaskRunner_KillTask(t *testing.T) {
alloc := mock.Alloc()
task := alloc.Job.TaskGroups[0].Tasks[0]
Expand Down Expand Up @@ -1148,7 +1208,7 @@ func TestTaskRunner_Template_NewVaultToken(t *testing.T) {
}

if originalManager == ctx.tr.templateManager {
return false, fmt.Errorf("Template manager not ctx.upd.ted")
return false, fmt.Errorf("Template manager not ctx.updated")
}

return true, nil
Expand Down

0 comments on commit 68b9735

Please sign in to comment.