diff --git a/interp/api.go b/interp/api.go index 0180394d5..930c320d3 100644 --- a/interp/api.go +++ b/interp/api.go @@ -528,6 +528,11 @@ func (r *Runner) Run(ctx context.Context, node syntax.Node) error { if !r.didReset { r.Reset() } + if r.stdin != nil { + if _, ok := r.stdin.(Canceler); !ok { + r.stdin = NewCancelableReader(ctx, r.stdin) + } + } r.fillExpandConfig(ctx) r.err = nil r.shellExited = false diff --git a/interp/reader.go b/interp/reader.go new file mode 100644 index 000000000..40bbac5c6 --- /dev/null +++ b/interp/reader.go @@ -0,0 +1,112 @@ +package interp + +import ( + "context" + "io" + "sync" +) + +var _ io.Reader = (*CancelableReader)(nil) +var _ io.Reader = (*CancelableReaderTTY)(nil) +var _ Canceler = (*CancelableReader)(nil) +var _ Canceler = (*CancelableReaderTTY)(nil) +var _ fder = (*CancelableReaderTTY)(nil) + +type CancelableReader struct { + ctx context.Context + cancel context.CancelFunc + in chan []byte + out chan readResult + err error + r io.Reader + once sync.Once +} + +type CancelableReaderTTY struct { + CancelableReader +} + +type Canceler interface { + Cancel() +} + +type readResult struct { + n int + err error +} + +func NewCancelableReader(ctx context.Context, r io.Reader) io.Reader { + ctx, cancel := context.WithCancel(ctx) + c := CancelableReader{ + r: r, + ctx: ctx, + cancel: cancel, + in: make(chan []byte), + out: make(chan readResult), + } + // Make sure [[ -t 0 ]] still works + if _, ok := r.(fder); ok { + return &CancelableReaderTTY{ + CancelableReader: c, + } + } + return &c +} + +// Read implements the io.Reader interface +func (c *CancelableReader) Read(p []byte) (int, error) { + c.once.Do(func() { go c.begin() }) + + // Send the buffer over to the reader goroutine + select { + case c.in <- p: + case <-c.ctx.Done(): + return 0, c.ctx.Err() + } + + // Get the output from the reader goroutine. + select { + case res, ok := <-c.out: + if !ok { + return 0, c.err + } + return res.n, res.err + case <-c.ctx.Done(): + return 0, c.ctx.Err() + } +} + +// Close implements the io.Closer interface. +func (c *CancelableReader) Close() error { + close(c.in) + if closer, ok := c.r.(io.Closer); ok { + return closer.Close() + } + return nil +} + +func (c *CancelableReader) begin() { + for c.ctx.Err() == nil { + select { + case buf, ok := <-c.in: + if !ok { + return + } + n, err := c.r.Read(buf) + select { + case c.out <- readResult{n: n, err: err}: + case <-c.ctx.Done(): + return + } + case <-c.ctx.Done(): + } + } +} + +func (c *CancelableReader) Cancel() { + c.cancel() +} + +func (ct *CancelableReaderTTY) Fd() uintptr { + return ct.r.(fder).Fd() +} diff --git a/interp/test.go b/interp/test.go index 156370196..94a9fa904 100644 --- a/interp/test.go +++ b/interp/test.go @@ -16,6 +16,10 @@ import ( "mvdan.cc/sh/v3/syntax" ) +type fder interface { + Fd() uintptr +} + // non-empty string is true, empty string is false func (r *Runner) bashTest(ctx context.Context, expr syntax.TestExpr, classic bool) string { switch x := expr.(type) { @@ -178,7 +182,7 @@ func (r *Runner) unTest(ctx context.Context, op syntax.UnTestOperator, x string) case 2: f = r.stderr } - if f, ok := f.(interface{ Fd() uintptr }); ok { + if f, ok := f.(fder); ok { // Support Fd methods such as the one on *os.File. return term.IsTerminal(int(f.Fd())) }