From 0aa93f87d62855ad3245f5f8a217399f1e450ed4 Mon Sep 17 00:00:00 2001 From: Peter McAtominey Date: Fri, 31 Jul 2020 23:10:53 +0100 Subject: [PATCH] mage: cancel context on SIGINT On receiving an interrupt signal, mage cancels the context allowing the magefile to perform any cleanup before exiting. A second interrupt signal will kill the magefile process without delay. The behaviour for a timeout remains unchanged (context is canclled and the magefile exits). --- mage/main.go | 6 +++ mage/main_test.go | 91 +++++++++++++++++++++++++++++++- mage/template.go | 46 ++++++++++------ mage/testdata/signals/signals.go | 47 +++++++++++++++++ 4 files changed, 172 insertions(+), 18 deletions(-) create mode 100644 mage/testdata/signals/signals.go diff --git a/mage/main.go b/mage/main.go index cccb0870..dd1bfc8b 100644 --- a/mage/main.go +++ b/mage/main.go @@ -11,11 +11,13 @@ import ( "log" "os" "os/exec" + "os/signal" "path/filepath" "regexp" "runtime" "sort" "strings" + "syscall" "text/template" "time" @@ -650,6 +652,10 @@ func RunCompiled(inv Invocation, exePath string, errlog *log.Logger) int { c.Env = append(c.Env, fmt.Sprintf("MAGEFILE_TIMEOUT=%s", inv.Timeout.String())) } debug.Print("running magefile with mage vars:\n", strings.Join(filter(c.Env, "MAGEFILE"), "\n")) + // catch SIGINT to allow magefile to handle them + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT) + defer signal.Stop(sigCh) err := c.Run() if !sh.CmdRan(err) { errlog.Printf("failed to run compiled magefile: %v", err) diff --git a/mage/main_test.go b/mage/main_test.go index 9fd5dbb4..a02fa963 100644 --- a/mage/main_test.go +++ b/mage/main_test.go @@ -20,6 +20,7 @@ import ( "runtime" "strconv" "strings" + "syscall" "testing" "time" @@ -1146,7 +1147,7 @@ func TestCompiledFlags(t *testing.T) { if err == nil { t.Fatalf("expected an error because of timeout") } - got = stdout.String() + got = stderr.String() want = "context deadline exceeded" if strings.Contains(got, want) == false { t.Errorf("got %q, does not contain %q", got, want) @@ -1235,7 +1236,7 @@ func TestCompiledEnvironmentVars(t *testing.T) { if err == nil { t.Fatalf("expected an error because of timeout") } - got = stdout.String() + got = stderr.String() want = "context deadline exceeded" if strings.Contains(got, want) == false { t.Errorf("got %q, does not contain %q", got, want) @@ -1305,6 +1306,92 @@ func TestCompiledVerboseFlag(t *testing.T) { } } +func TestSignals(t *testing.T) { + stderr := &bytes.Buffer{} + stdout := &bytes.Buffer{} + dir := "./testdata/signals" + compileDir, err := ioutil.TempDir(dir, "") + if err != nil { + t.Fatal(err) + } + name := filepath.Join(compileDir, "mage_out") + // The CompileOut directory is relative to the + // invocation directory, so chop off the invocation dir. + outName := "./" + name[len(dir)-1:] + defer os.RemoveAll(compileDir) + inv := Invocation{ + Dir: dir, + Stdout: stdout, + Stderr: stderr, + CompileOut: outName, + } + code := Invoke(inv) + if code != 0 { + t.Errorf("expected to exit with code 0, but got %v, stderr: %s", code, stderr) + } + + run := func(stdout, stderr *bytes.Buffer, filename string, target string, signals ...syscall.Signal) error { + stderr.Reset() + stdout.Reset() + cmd := exec.Command(filename, target) + cmd.Stderr = stderr + cmd.Stdout = stdout + if err := cmd.Start(); err != nil { + return fmt.Errorf("running '%s %s' failed with: %v\nstdout: %s\nstderr: %s", + filename, target, err, stdout, stderr) + } + pid := cmd.Process.Pid + go func() { + time.Sleep(time.Millisecond * 500) + for _, s := range signals { + syscall.Kill(pid, s) + time.Sleep(time.Millisecond * 50) + } + }() + if err := cmd.Wait(); err != nil { + return fmt.Errorf("running '%s %s' failed with: %v\nstdout: %s\nstderr: %s", + filename, target, err, stdout, stderr) + } + return nil + } + + if err := run(stdout, stderr, name, "exitsAfterSighup", syscall.SIGHUP); err != nil { + t.Fatal(err) + } + got := stdout.String() + want := "received sighup\n" + if strings.Contains(got, want) == false { + t.Errorf("got %q, does not contain %q", got, want) + } + + if err := run(stdout, stderr, name, "exitsAfterSigint", syscall.SIGINT); err != nil { + t.Fatal(err) + } + got = stdout.String() + want = "exiting...done\n" + if strings.Contains(got, want) == false { + t.Errorf("got %q, does not contain %q", got, want) + } + + if err := run(stdout, stderr, name, "exitsAfterCancel", syscall.SIGINT); err != nil { + t.Fatal(err) + } + got = stdout.String() + want = "exiting...done\n" + if strings.Contains(got, want) == false { + t.Errorf("got %q, does not contain %q", got, want) + } + + if err := run(stdout, stderr, name, "ignoresSignals", syscall.SIGINT, syscall.SIGINT); err == nil { + t.Fatalf("expected an error because of force kill") + } + got = stderr.String() + want = "Error: target killed\n" + if strings.Contains(got, want) == false { + t.Errorf("got %q, does not contain %q", got, want) + } +} + func TestClean(t *testing.T) { if err := os.RemoveAll(mg.CacheDir()); err != nil { t.Error("error removing cache dir:", err) diff --git a/mage/template.go b/mage/template.go index fbe69719..af7b7007 100644 --- a/mage/template.go +++ b/mage/template.go @@ -14,10 +14,12 @@ import ( "io/ioutil" "log" "os" + "os/signal" "path/filepath" "sort" "strconv" "strings" + "syscall" "text/tabwriter" "time" {{range .Imports}}{{.UniqueName}} "{{.Path}}" @@ -260,17 +262,19 @@ Options: var ctxCancel func() getContext := func() (context.Context, func()) { - if ctx != nil { - return ctx, ctxCancel + if ctx == nil { + ctx, ctxCancel = context.WithCancel(context.Background()) } + return ctx, ctxCancel + } + + getTimeout := func() <-chan time.Time { if args.Timeout != 0 { - ctx, ctxCancel = context.WithTimeout(context.Background(), args.Timeout) - } else { - ctx = context.Background() - ctxCancel = func() {} + return time.After(args.Timeout) } - return ctx, ctxCancel + + return make(chan time.Time) } runTarget := func(fn func(context.Context) error) interface{} { @@ -285,15 +289,25 @@ Options: err := fn(ctx) d <- err }() - select { - case <-ctx.Done(): - cancel() - e := ctx.Err() - fmt.Printf("ctx err: %v\n", e) - return e - case err = <-d: - cancel() - return err + timeoutCh := getTimeout() + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT) + for { + select { + case <-sigCh: + select { + case <-ctx.Done(): + return fmt.Errorf("target killed") + default: + cancel() + } + case <-timeoutCh: + cancel() + return fmt.Errorf("context deadline exceeded") + case err = <-d: + cancel() + return err + } } } // This is necessary in case there aren't any targets, to avoid an unused diff --git a/mage/testdata/signals/signals.go b/mage/testdata/signals/signals.go new file mode 100644 index 00000000..4c58116b --- /dev/null +++ b/mage/testdata/signals/signals.go @@ -0,0 +1,47 @@ +//+build mage + +package main + +import ( + "context" + "fmt" + "os" + "os/signal" + "syscall" + "time" +) + +// Exits after receiving SIGHUP +func ExitsAfterSighup(ctx context.Context) { + sigC := make(chan os.Signal, 1) + signal.Notify(sigC, syscall.SIGHUP) + <-sigC + fmt.Println("received sighup") +} + +// Exits after SIGINT and wait +func ExitsAfterSigint(ctx context.Context) { + sigC := make(chan os.Signal, 1) + signal.Notify(sigC, syscall.SIGINT) + <-sigC + fmt.Printf("exiting...") + time.Sleep(200 * time.Millisecond) + fmt.Println("done") +} + +// Exits after ctx cancel and wait +func ExitsAfterCancel(ctx context.Context) { + <-ctx.Done() + fmt.Printf("exiting...") + time.Sleep(200 * time.Millisecond) + fmt.Println("done") +} + +// Ignores all signals, requires killing +func IgnoresSignals(ctx context.Context) { + sigC := make(chan os.Signal, 1) + signal.Notify(sigC, syscall.SIGINT) + for { + <-sigC + } +}