Skip to content

Commit

Permalink
Update Cmd IO handling
Browse files Browse the repository at this point in the history
Have `internal\cmd.Cmd` ignore relay and close errors from the
underlying IO channel being closed, since not all
`io.Reader`/`io.Writer`s return an `io.EOF` if the are closed during an
IO operation.

This standardizes behavior between an `hcs`/`gcs` `Process` (which use a
`go-winio.win32File` for their IO channels, and return `io.EOF` for
read and write operations on a closed handle) and `JobProcess` (which
uses an `os.Pipe` that instead return an `os.ErrClosed`).

Additionally, ignore errors from closing an already-closed std IO
channel.

Update `Cmd.Wait` to return a known error value if it times out waiting
on IO copy after the command exits (and update `TestCmdStuckIo` to check
for that error).
Prior, the test checked for an `io.ErrClosedPipe`, which:
1. is not the best indicator that IO is stuck; and
2. is now ignored as an error value raised during IO relay.

Signed-off-by: Hamza El-Saawy <[email protected]>
  • Loading branch information
helsaawy committed Dec 11, 2023
1 parent 7ec8848 commit 51850d7
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 17 deletions.
38 changes: 32 additions & 6 deletions internal/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package cmd
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"strings"
Expand All @@ -20,6 +21,8 @@ import (
"golang.org/x/sys/windows"
)

var errIOTimeOut = errors.New("timed out waiting for stdio relay")

// CmdProcessRequest stores information on command requests made through this package.
type CmdProcessRequest struct {
Args []string
Expand Down Expand Up @@ -136,9 +139,19 @@ func CommandContext(ctx context.Context, host cow.ProcessHost, name string, arg
// Start starts a command. The caller must ensure that if Start succeeds,
// Wait is eventually called to clean up resources.
func (c *Cmd) Start() error {
if c.Host == nil {
return errors.New("empty ProcessHost")
}

// closed in (*Cmd).Wait; signals command execution is done
c.allDoneCh = make(chan struct{})

var x interface{}
if !c.Host.IsOCI() {
if c.Spec == nil {
return errors.New("process spec is required for non-OCI ProcessHost")
}

wpp := &hcsschema.ProcessParameters{
CommandLine: c.Spec.CommandLine,
User: c.Spec.User.Username,
Expand Down Expand Up @@ -211,7 +224,7 @@ func (c *Cmd) Start() error {
c.stdinErr.Store(err)
}
// Notify the process that there is no more input.
if err := p.CloseStdin(context.TODO()); err != nil && c.Log != nil {
if err := p.CloseStdin(context.TODO()); err != nil && !isClosedIOErr(err) && c.Log != nil {
c.Log.WithError(err).Warn("failed to close Cmd stdin")
}
}()
Expand All @@ -220,8 +233,8 @@ func (c *Cmd) Start() error {
if c.Stdout != nil {
c.iogrp.Go(func() error {
_, err := relayIO(c.Stdout, stdout, c.Log, "stdout")
if err := p.CloseStdout(context.TODO()); err != nil {
c.Log.WithError(err).Warn("failed to close Cmd stdout")
if cErr := p.CloseStdout(context.TODO()); cErr != nil && !isClosedIOErr(cErr) && c.Log != nil {
c.Log.WithError(cErr).Warn("failed to close Cmd stdout")
}
return err
})
Expand All @@ -230,8 +243,8 @@ func (c *Cmd) Start() error {
if c.Stderr != nil {
c.iogrp.Go(func() error {
_, err := relayIO(c.Stderr, stderr, c.Log, "stderr")
if err := p.CloseStderr(context.TODO()); err != nil {
c.Log.WithError(err).Warn("failed to close Cmd stderr")
if cErr := p.CloseStderr(context.TODO()); cErr != nil && !isClosedIOErr(cErr) && c.Log != nil {
c.Log.WithError(cErr).Warn("failed to close Cmd stderr")
}
return err
})
Expand Down Expand Up @@ -270,27 +283,40 @@ func (c *Cmd) Wait() error {
state.exited = true
state.code = code
}

// Terminate the IO if the copy does not complete in the requested time.
timeoutErrCh := make(chan error)
if c.CopyAfterExitTimeout != 0 {
go func() {
defer close(timeoutErrCh)
t := time.NewTimer(c.CopyAfterExitTimeout)
defer t.Stop()
select {
case <-c.allDoneCh:
case <-t.C:
// Close the process to cancel any reads to stdout or stderr.
c.Process.Close()
err := errIOTimeOut
// log the timeout, since we may not return it to the caller
if c.Log != nil {
c.Log.Warn("timed out waiting for stdio relay")
c.Log.WithField("timeout", c.CopyAfterExitTimeout).Warn(err.Error())
}
timeoutErrCh <- err
}
}()
} else {
close(timeoutErrCh)
}

// TODO (go1.20): use multierror for these
ioErr := c.iogrp.Wait()
if ioErr == nil {
ioErr, _ = c.stdinErr.Load().(error)
}
close(c.allDoneCh)
if tErr := <-timeoutErrCh; ioErr == nil {
ioErr = tErr
}
c.Process.Close()
c.ExitState = state
if exitErr != nil {
Expand Down
18 changes: 9 additions & 9 deletions internal/cmd/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,22 +213,22 @@ func TestCmdStdinBlocked(t *testing.T) {
}
}

type stuckIoProcessHost struct {
type stuckIOProcessHost struct {
cow.ProcessHost
}

type stuckIoProcess struct {
type stuckIOProcess struct {
cow.Process
stdin, pstdout, pstderr *io.PipeWriter
pstdin, stdout, stderr *io.PipeReader
}

func (h *stuckIoProcessHost) CreateProcess(ctx context.Context, cfg interface{}) (cow.Process, error) {
func (h *stuckIOProcessHost) CreateProcess(ctx context.Context, cfg interface{}) (cow.Process, error) {
p, err := h.ProcessHost.CreateProcess(ctx, cfg)
if err != nil {
return nil, err
}
sp := &stuckIoProcess{
sp := &stuckIOProcess{
Process: p,
}
sp.pstdin, sp.stdin = io.Pipe()
Expand All @@ -237,22 +237,22 @@ func (h *stuckIoProcessHost) CreateProcess(ctx context.Context, cfg interface{})
return sp, nil
}

func (p *stuckIoProcess) Stdio() (io.Writer, io.Reader, io.Reader) {
func (p *stuckIOProcess) Stdio() (io.Writer, io.Reader, io.Reader) {
return p.stdin, p.stdout, p.stderr
}

func (p *stuckIoProcess) Close() error {
func (p *stuckIOProcess) Close() error {
p.stdin.Close()
p.stdout.Close()
p.stderr.Close()
return p.Process.Close()
}

func TestCmdStuckIo(t *testing.T) {
cmd := Command(&stuckIoProcessHost{&localProcessHost{}}, "cmd", "/c", "echo", "hello")
cmd := Command(&stuckIOProcessHost{&localProcessHost{}}, "cmd", "/c", "echo", "hello")
cmd.CopyAfterExitTimeout = time.Millisecond * 200
_, err := cmd.Output()
if err != io.ErrClosedPipe { //nolint:errorlint
t.Fatal(err)
if !errors.Is(err, errIOTimeOut) {
t.Fatalf("expected: %v; got: %v", errIOTimeOut, err)
}
}
41 changes: 39 additions & 2 deletions internal/cmd/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,18 @@ package cmd

import (
"context"
"errors"
"fmt"
"io"
"net"
"net/url"
"os"
"time"

"github.com/pkg/errors"
"github.com/Microsoft/go-winio"
"github.com/sirupsen/logrus"

"github.com/Microsoft/hcsshim/internal/hcs"
)

// UpstreamIO is an interface describing the IO to connect to above the shim.
Expand Down Expand Up @@ -57,13 +63,40 @@ func NewUpstreamIO(ctx context.Context, id, stdout, stderr, stdin string, termin

// Create IO for binary logging driver.
if u.Scheme != "binary" {
return nil, errors.Errorf("scheme must be 'binary', got: '%s'", u.Scheme)
return nil, fmt.Errorf("scheme must be 'binary', got: '%s'", u.Scheme)
}

return NewBinaryIO(ctx, id, u)
}

// isClosedIOErr checks if the error is from the underlying file or pipe already being closed.
func isClosedIOErr(err error) bool {
for _, e := range []error{
os.ErrClosed,
net.ErrClosed,
io.ErrClosedPipe,
winio.ErrFileClosed,
hcs.ErrAlreadyClosed,
} {
if errors.Is(err, e) {
return true
}
}
return false
}

// relayIO is a glorified io.Copy that also logs when the copy has completed.
//
// It will ignore errors raised during the copy from attempting to read from
// (or write to) a closed io.Reader (or Writer, respectively).
// Ideally, this would not be necessary, since the command's stdout and stderr would
// send an EOF first before closing, but that is not always the case (eg, [jobcontainer.JobProcess]
// uses unnamed pipes, which do not support EOF).
// Additionally, we do not prevent writing to the stdin of a closed Cmd, so there could be a race
// between reading the upstream stdin, the command finishing, and attempting to write to the command's
// stdin writer.
//
// See [isClosedIOErr] for the errors that are ignored.
func relayIO(w io.Writer, r io.Reader, log *logrus.Entry, name string) (int64, error) {
n, err := io.Copy(w, r)
if log != nil {
Expand All @@ -72,6 +105,10 @@ func relayIO(w io.Writer, r io.Reader, log *logrus.Entry, name string) (int64, e
"file": name,
"bytes": n,
})
if isClosedIOErr(err) {
log.WithError(err).Trace("ignoring closed IO error")
err = nil
}
if err != nil {
lvl = logrus.ErrorLevel
log = log.WithError(err)
Expand Down

0 comments on commit 51850d7

Please sign in to comment.