diff --git a/cli-plugins/plugin/plugin.go b/cli-plugins/plugin/plugin.go index 73059578f983..69db2422416e 100644 --- a/cli-plugins/plugin/plugin.go +++ b/cli-plugins/plugin/plugin.go @@ -1,8 +1,12 @@ package plugin import ( + "context" "encoding/json" + "errors" "fmt" + "io" + "net" "os" "sync" @@ -14,6 +18,11 @@ import ( "github.com/spf13/cobra" ) +// CLIPluginSocketEnvKey is used to pass the plugin being +// executed the abstract socket name it should listen on to know +// when the CLI has exited. +const CLIPluginSocketEnvKey = "DOCKER_CLI_PLUGIN_SOCKET" + // PersistentPreRunE must be called by any plugin command (or // subcommand) which uses the cobra `PersistentPreRun*` hook. Plugins // which do not make use of `PersistentPreRun*` do not need to call @@ -24,14 +33,56 @@ import ( // called. var PersistentPreRunE func(*cobra.Command, []string) error +// closeOnCLISocketClose connects to the socket specified +// by the DOCKER_CLI_PLUGIN_SOCKET env var, if present, and attempts +// to read from it until it receives an EOF, which signals that +// the CLI is going to exit and the plugin should also exit. +func closeOnCLISocketClose(cancel func()) { + socketAddr, ok := os.LookupEnv(CLIPluginSocketEnvKey) + if !ok { + // if a plugin compiled against a more recent version of docker/cli + // is executed by an older CLI binary, ignore missing environment + // variable and behave as usual + return + } + addr, err := net.ResolveUnixAddr("unix", socketAddr) + if err != nil { + return + } + cliCloseConn, err := net.DialUnix("unix", nil, addr) + if err != nil { + return + } + + go func() { + b := make([]byte, 1) + for { + _, err := cliCloseConn.Read(b) + if errors.Is(err, io.EOF) { + cancel() + } + } + }() +} + // RunPlugin executes the specified plugin command func RunPlugin(dockerCli *command.DockerCli, plugin *cobra.Command, meta manager.Metadata) error { tcmd := newPluginCommand(dockerCli, plugin, meta) var persistentPreRunOnce sync.Once - PersistentPreRunE = func(_ *cobra.Command, _ []string) error { + PersistentPreRunE = func(cmd *cobra.Command, _ []string) error { var err error persistentPreRunOnce.Do(func() { + cmdContext := cmd.Context() + // TODO: revisit and make sure this check makes sense + // see: https://github.com/docker/cli/pull/4599#discussion_r1422487271 + if cmdContext == nil { + cmdContext = context.TODO() + } + ctx, cancel := context.WithCancel(cmdContext) + cmd.SetContext(ctx) + closeOnCLISocketClose(cancel) + var opts []command.InitializeOpt if os.Getenv("DOCKER_CLI_PLUGIN_USE_DIAL_STDIO") != "" { opts = append(opts, withPluginClientConn(plugin.Name())) diff --git a/cmd/docker/docker.go b/cmd/docker/docker.go index e55bed358884..e2d8d07f0882 100644 --- a/cmd/docker/docker.go +++ b/cmd/docker/docker.go @@ -2,18 +2,22 @@ package main import ( "fmt" + "net" "os" "os/exec" + "os/signal" "strings" "syscall" "github.com/docker/cli/cli" pluginmanager "github.com/docker/cli/cli-plugins/manager" + "github.com/docker/cli/cli-plugins/plugin" "github.com/docker/cli/cli/command" "github.com/docker/cli/cli/command/commands" cliflags "github.com/docker/cli/cli/flags" "github.com/docker/cli/cli/version" - "github.com/docker/cli/cmd/docker/internal/appcontext" + platformsignals "github.com/docker/cli/cmd/docker/internal/signals" + "github.com/docker/distribution/uuid" "github.com/docker/docker/api/types/versions" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -187,6 +191,13 @@ func setValidateArgs(dockerCli command.Cli, cmd *cobra.Command) { }) } +func setupPluginSocket() (*net.UnixListener, error) { + return net.ListenUnix("unix", &net.UnixAddr{ + Name: "@docker_cli_" + uuid.Generate().String(), + Net: "unix", + }) +} + func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string, envs []string) error { plugincmd, err := pluginmanager.PluginRunCommand(dockerCli, subcommand, cmd) if err != nil { @@ -194,9 +205,45 @@ func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string, } plugincmd.Env = append(envs, plugincmd.Env...) + var conn *net.UnixConn + listener, err := setupPluginSocket() + if err == nil { + defer listener.Close() + plugincmd.Env = append(plugincmd.Env, plugin.CLIPluginSocketEnvKey+"="+listener.Addr().String()) + + go func() { + for { + // ignore error here, if we failed to accept a connection, + // conn is nil and we fallback to previous behavior + conn, _ = listener.AcceptUnix() + } + }() + } + + const exitLimit = 3 + + signals := make(chan os.Signal, exitLimit) + signal.Notify(signals, platformsignals.TerminationSignals...) + // signal handling goroutine: listen on signals channel, and if conn is + // non-nil, attempt to close it to let the plugin know to exit. Regardless + // of whether we successfully signal the plugin or not, after 3 SIGINTs, + // we send a SIGKILL to the plugin process and exit go func() { - // override SIGTERM handler so we let the plugin shut down first - <-appcontext.Context().Done() + retries := 0 + for range signals { + if conn != nil { + if err := conn.Close(); err != nil { + _, _ = fmt.Fprintf(dockerCli.Err(), "failed to signal plugin to close: %v\n", err) + } + conn = nil + } + retries++ + if retries >= exitLimit { + _, _ = fmt.Fprintf(dockerCli.Err(), "got %d SIGTERM/SIGINTs, forcefully exiting\n", retries) + _ = plugincmd.Process.Kill() + os.Exit(1) + } + } }() if err := plugincmd.Run(); err != nil { diff --git a/cmd/docker/internal/appcontext/appcontext.go b/cmd/docker/internal/appcontext/appcontext.go deleted file mode 100644 index f41f4b6d7508..000000000000 --- a/cmd/docker/internal/appcontext/appcontext.go +++ /dev/null @@ -1,44 +0,0 @@ -package appcontext - -import ( - "context" - "os" - "os/signal" - "sync" - - "github.com/sirupsen/logrus" -) - -var ( - appContextCache context.Context - appContextOnce sync.Once -) - -// Context returns a static context that reacts to termination signals of the -// running process. Useful in CLI tools. -func Context() context.Context { - appContextOnce.Do(func() { - signals := make(chan os.Signal, 2048) - signal.Notify(signals, terminationSignals...) - - const exitLimit = 3 - retries := 0 - - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - appContextCache = ctx - - go func() { - for { - <-signals - cancel() - retries++ - if retries >= exitLimit { - logrus.Errorf("got %d SIGTERM/SIGINTs, forcing shutdown", retries) - os.Exit(1) - } - } - }() - }) - return appContextCache -} diff --git a/cmd/docker/internal/appcontext/appcontext_unix.go b/cmd/docker/internal/appcontext/appcontext_unix.go deleted file mode 100644 index 366edc68b399..000000000000 --- a/cmd/docker/internal/appcontext/appcontext_unix.go +++ /dev/null @@ -1,12 +0,0 @@ -//go:build !windows -// +build !windows - -package appcontext - -import ( - "os" - - "golang.org/x/sys/unix" -) - -var terminationSignals = []os.Signal{unix.SIGTERM, unix.SIGINT} diff --git a/cmd/docker/internal/appcontext/appcontext_windows.go b/cmd/docker/internal/appcontext/appcontext_windows.go deleted file mode 100644 index 0a8bcbe7df2a..000000000000 --- a/cmd/docker/internal/appcontext/appcontext_windows.go +++ /dev/null @@ -1,7 +0,0 @@ -package appcontext - -import ( - "os" -) - -var terminationSignals = []os.Signal{os.Interrupt} diff --git a/cmd/docker/internal/signals/signals_unix.go b/cmd/docker/internal/signals/signals_unix.go new file mode 100644 index 000000000000..c22058a90866 --- /dev/null +++ b/cmd/docker/internal/signals/signals_unix.go @@ -0,0 +1,14 @@ +//go:build unix +// +build unix + +package signals + +import ( + "os" + + "golang.org/x/sys/unix" +) + +// TerminationSignals represents the list of signals we +// want to special-case handle, on this platform. +var TerminationSignals = []os.Signal{unix.SIGTERM, unix.SIGINT} diff --git a/cmd/docker/internal/signals/signals_windows.go b/cmd/docker/internal/signals/signals_windows.go new file mode 100644 index 000000000000..d6c5773f3668 --- /dev/null +++ b/cmd/docker/internal/signals/signals_windows.go @@ -0,0 +1,7 @@ +package signals + +import "os" + +// TerminationSignals represents the list of signals we +// want to special-case handle, on this platform. +var TerminationSignals = []os.Signal{os.Interrupt}