diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go index d7228619eb..c11dd3ca7a 100644 --- a/internal/cmd/cmd.go +++ b/internal/cmd/cmd.go @@ -5,6 +5,7 @@ package cmd import ( "bytes" "context" + "errors" "fmt" "io" "strings" @@ -20,6 +21,8 @@ import ( "golang.org/x/sys/windows" ) +var errIOTimeOut = errors.New("timed out waiting for stdio relay") + // CmdProcessRequest stores information on command requests made through this package. type CmdProcessRequest struct { Args []string @@ -136,9 +139,19 @@ func CommandContext(ctx context.Context, host cow.ProcessHost, name string, arg // Start starts a command. The caller must ensure that if Start succeeds, // Wait is eventually called to clean up resources. func (c *Cmd) Start() error { + if c.Host == nil { + return errors.New("empty ProcessHost") + } + + // closed in (*Cmd).Wait; signals command execution is done c.allDoneCh = make(chan struct{}) + var x interface{} if !c.Host.IsOCI() { + if c.Spec == nil { + return errors.New("process spec is required for non-OCI ProcessHost") + } + wpp := &hcsschema.ProcessParameters{ CommandLine: c.Spec.CommandLine, User: c.Spec.User.Username, @@ -211,7 +224,7 @@ func (c *Cmd) Start() error { c.stdinErr.Store(err) } // Notify the process that there is no more input. - if err := p.CloseStdin(context.TODO()); err != nil && c.Log != nil { + if err := p.CloseStdin(context.TODO()); err != nil && !isClosedIOErr(err) && c.Log != nil { c.Log.WithError(err).Warn("failed to close Cmd stdin") } }() @@ -220,8 +233,8 @@ func (c *Cmd) Start() error { if c.Stdout != nil { c.iogrp.Go(func() error { _, err := relayIO(c.Stdout, stdout, c.Log, "stdout") - if err := p.CloseStdout(context.TODO()); err != nil { - c.Log.WithError(err).Warn("failed to close Cmd stdout") + if cErr := p.CloseStdout(context.TODO()); cErr != nil && !isClosedIOErr(cErr) && c.Log != nil { + c.Log.WithError(cErr).Warn("failed to close Cmd stdout") } return err }) @@ -230,8 +243,8 @@ func (c *Cmd) Start() error { if c.Stderr != nil { c.iogrp.Go(func() error { _, err := relayIO(c.Stderr, stderr, c.Log, "stderr") - if err := p.CloseStderr(context.TODO()); err != nil { - c.Log.WithError(err).Warn("failed to close Cmd stderr") + if cErr := p.CloseStderr(context.TODO()); cErr != nil && !isClosedIOErr(cErr) && c.Log != nil { + c.Log.WithError(cErr).Warn("failed to close Cmd stderr") } return err }) @@ -270,9 +283,12 @@ func (c *Cmd) Wait() error { state.exited = true state.code = code } + // Terminate the IO if the copy does not complete in the requested time. + timeoutErrCh := make(chan error) if c.CopyAfterExitTimeout != 0 { go func() { + defer close(timeoutErrCh) t := time.NewTimer(c.CopyAfterExitTimeout) defer t.Stop() select { @@ -280,17 +296,27 @@ func (c *Cmd) Wait() error { case <-t.C: // Close the process to cancel any reads to stdout or stderr. c.Process.Close() + err := errIOTimeOut + // log the timeout, since we may not return it to the caller if c.Log != nil { - c.Log.Warn("timed out waiting for stdio relay") + c.Log.WithField("timeout", c.CopyAfterExitTimeout).Warn(err.Error()) } + timeoutErrCh <- err } }() + } else { + close(timeoutErrCh) } + + // TODO (go1.20): use multierror for these ioErr := c.iogrp.Wait() if ioErr == nil { ioErr, _ = c.stdinErr.Load().(error) } close(c.allDoneCh) + if tErr := <-timeoutErrCh; ioErr == nil { + ioErr = tErr + } c.Process.Close() c.ExitState = state if exitErr != nil { diff --git a/internal/cmd/cmd_test.go b/internal/cmd/cmd_test.go index 00eb86c7b0..109343d9f2 100644 --- a/internal/cmd/cmd_test.go +++ b/internal/cmd/cmd_test.go @@ -213,22 +213,22 @@ func TestCmdStdinBlocked(t *testing.T) { } } -type stuckIoProcessHost struct { +type stuckIOProcessHost struct { cow.ProcessHost } -type stuckIoProcess struct { +type stuckIOProcess struct { cow.Process stdin, pstdout, pstderr *io.PipeWriter pstdin, stdout, stderr *io.PipeReader } -func (h *stuckIoProcessHost) CreateProcess(ctx context.Context, cfg interface{}) (cow.Process, error) { +func (h *stuckIOProcessHost) CreateProcess(ctx context.Context, cfg interface{}) (cow.Process, error) { p, err := h.ProcessHost.CreateProcess(ctx, cfg) if err != nil { return nil, err } - sp := &stuckIoProcess{ + sp := &stuckIOProcess{ Process: p, } sp.pstdin, sp.stdin = io.Pipe() @@ -237,11 +237,11 @@ func (h *stuckIoProcessHost) CreateProcess(ctx context.Context, cfg interface{}) return sp, nil } -func (p *stuckIoProcess) Stdio() (io.Writer, io.Reader, io.Reader) { +func (p *stuckIOProcess) Stdio() (io.Writer, io.Reader, io.Reader) { return p.stdin, p.stdout, p.stderr } -func (p *stuckIoProcess) Close() error { +func (p *stuckIOProcess) Close() error { p.stdin.Close() p.stdout.Close() p.stderr.Close() @@ -249,10 +249,10 @@ func (p *stuckIoProcess) Close() error { } func TestCmdStuckIo(t *testing.T) { - cmd := Command(&stuckIoProcessHost{&localProcessHost{}}, "cmd", "/c", "echo", "hello") + cmd := Command(&stuckIOProcessHost{&localProcessHost{}}, "cmd", "/c", "echo", "hello") cmd.CopyAfterExitTimeout = time.Millisecond * 200 _, err := cmd.Output() - if err != io.ErrClosedPipe { //nolint:errorlint - t.Fatal(err) + if !errors.Is(err, errIOTimeOut) { + t.Fatalf("expected: %v; got: %v", errIOTimeOut, err) } } diff --git a/internal/cmd/io.go b/internal/cmd/io.go index 75ddd1f355..8653bfed21 100644 --- a/internal/cmd/io.go +++ b/internal/cmd/io.go @@ -4,12 +4,18 @@ package cmd import ( "context" + "errors" + "fmt" "io" + "net" "net/url" + "os" "time" - "github.com/pkg/errors" + "github.com/Microsoft/go-winio" "github.com/sirupsen/logrus" + + "github.com/Microsoft/hcsshim/internal/hcs" ) // UpstreamIO is an interface describing the IO to connect to above the shim. @@ -57,13 +63,40 @@ func NewUpstreamIO(ctx context.Context, id, stdout, stderr, stdin string, termin // Create IO for binary logging driver. if u.Scheme != "binary" { - return nil, errors.Errorf("scheme must be 'binary', got: '%s'", u.Scheme) + return nil, fmt.Errorf("scheme must be 'binary', got: '%s'", u.Scheme) } return NewBinaryIO(ctx, id, u) } +// isClosedIOErr checks if the error is from the underlying file or pipe already being closed. +func isClosedIOErr(err error) bool { + for _, e := range []error{ + os.ErrClosed, + net.ErrClosed, + io.ErrClosedPipe, + winio.ErrFileClosed, + hcs.ErrAlreadyClosed, + } { + if errors.Is(err, e) { + return true + } + } + return false +} + // relayIO is a glorified io.Copy that also logs when the copy has completed. +// +// It will ignore errors raised during the copy from attempting to read from +// (or write to) a closed io.Reader (or Writer, respectively). +// Ideally, this would not be necessary, since the command's stdout and stderr would +// send an EOF first before closing, but that is not always the case (eg, [jobcontainer.JobProcess] +// uses unnamed pipes, which do not support EOF). +// Additionally, we do not prevent writing to the stdin of a closed Cmd, so there could be a race +// between reading the upstream stdin, the command finishing, and attempting to write to the command's +// stdin writer. +// +// See [isClosedIOErr] for the errors that are ignored. func relayIO(w io.Writer, r io.Reader, log *logrus.Entry, name string) (int64, error) { n, err := io.Copy(w, r) if log != nil { @@ -72,6 +105,10 @@ func relayIO(w io.Writer, r io.Reader, log *logrus.Entry, name string) (int64, e "file": name, "bytes": n, }) + if isClosedIOErr(err) { + log.WithError(err).Trace("ignoring closed IO error") + err = nil + } if err != nil { lvl = logrus.ErrorLevel log = log.WithError(err)