Skip to content

Commit

Permalink
Fixing running pseudo-tty commands (#3518)
Browse files Browse the repository at this point in the history
* fix: pty exec

* fix: input/output error

* chore: fix lint
  • Loading branch information
levkohimins authored Oct 26, 2024
1 parent a55ed8c commit 45f3160
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 38 deletions.
4 changes: 1 addition & 3 deletions internal/errors/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,7 @@ func UnwrapMultiErrors(err error) []error {
errs = append(errs[:index], errs[index+1:]...)
index--

for _, err := range err.Unwrap() {
errs = append(errs, New(err))
}
errs = append(errs, err.Unwrap()...)

break
}
Expand Down
90 changes: 58 additions & 32 deletions internal/os/exec/ptty_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,95 +4,121 @@
package exec

import (
"context"
"io"
"os"
"os/exec"
"os/signal"
"syscall"

"golang.org/x/sync/errgroup"
"golang.org/x/term"

"github.com/creack/pty"
"github.com/gruntwork-io/terragrunt/internal/errors"
"github.com/gruntwork-io/terragrunt/pkg/log"
"github.com/gruntwork-io/terragrunt/util"
)

// runCommandWithPTY will allocate a pseudo-tty to run the subcommand in. This is only necessary when running
// interactive commands, so that terminal features like readline work through the subcommand when stdin, stdout, and
// stderr is being shared.
// NOTE: This is based on the quickstart example from https://github.com/creack/pty
func runCommandWithPTY(logger log.Logger, cmd *exec.Cmd) (err error) {
cmdStdout := cmd.Stdout

cmd.Stdin = nil
cmd.Stdout = nil
cmd.Stderr = nil

// NOTE: in order to ensure we can return errors that occur in cleanup, we use a variable binding for the return
// value so that it can be updated.
pseudoTerminal, startErr := pty.Start(cmd)
pseudoTerminal, err := pty.Start(cmd)
if err != nil {
return errors.New(err)
}

defer func() {
if closeErr := pseudoTerminal.Close(); closeErr != nil {
logger.Errorf("Error closing pty: %s", closeErr)
closeErr = errors.Errorf("Error closing pty: %w", closeErr)

// Only overwrite the previous error if there was no error since this error has lower priority than any
// errors in the main routine
if err == nil {
err = errors.New(closeErr)
err = closeErr
} else {
logger.Error(closeErr)
}
}
}()

if startErr != nil {
return errors.New(startErr)
}

// Every time the current terminal size changes, we need to make sure the PTY also updates the size.
ch := make(chan os.Signal, 1)
signal.Notify(ch, syscall.SIGWINCH)

go func() {
for range ch {
if inheritSizeErr := pty.InheritSize(os.Stdin, pseudoTerminal); inheritSizeErr != nil {
inheritSizeErr = errors.Errorf("Error resizing pty: %w", inheritSizeErr)

// We don't propagate this error upstream because it does not affect normal operation of the command
logger.Errorf("error resizing pty: %s", inheritSizeErr)
logger.Error(inheritSizeErr)
}
}
}()
ch <- syscall.SIGWINCH // Make sure the pty matches current size

// Set stdin in raw mode so that we preserve readline properties
oldState, setRawErr := term.MakeRaw(int(os.Stdin.Fd()))
if setRawErr != nil {
return errors.New(setRawErr)
oldState, err := term.MakeRaw(int(os.Stdin.Fd()))
if err != nil {
return errors.New(err)
}

defer func() {
if restoreErr := term.Restore(int(os.Stdin.Fd()), oldState); restoreErr != nil {
logger.Errorf("Error restoring terminal state: %s", restoreErr)
restoreErr = errors.Errorf("error restoring terminal state: %w", restoreErr)

// Only overwrite the previous error if there was no error since this error has lower priority than any
// errors in the main routine
if err == nil {
err = errors.New(restoreErr)
err = restoreErr
} else {
logger.Error(restoreErr)
}
}
}()

stdinDone := make(chan error, 1)
// Copy stdin to the pty
go func() {
_, copyStdinErr := io.Copy(pseudoTerminal, os.Stdin)
// We don't propagate this error upstream because it does not affect normal operation of the command. A repeat
// of the same stdin in this case should resolve the issue.
if copyStdinErr != nil {
logger.Errorf("Error forwarding stdin: %s", copyStdinErr)
ctx := context.Background()

ctx, cancel := context.WithCancel(ctx)
defer cancel()

errGroup, ctx := errgroup.WithContext(ctx)

// Copy stdout to the pty.
errGroup.Go(func() error {
defer cancel()

if _, err := util.Copy(ctx, cmdStdout, pseudoTerminal); err != nil {
return errors.Errorf("error forwarding stdout: %w", err)
}
// signal that stdin copy is done
stdinDone <- copyStdinErr
}()

// ... and the pty to stdout.
_, copyStdoutErr := io.Copy(cmd.Stdout, pseudoTerminal)
if copyStdoutErr != nil {
return errors.New(copyStdoutErr)
}
return nil
})

// Copy stdin to the pty.
errGroup.Go(func() error {
defer cancel()

if _, err := util.Copy(ctx, pseudoTerminal, os.Stdin); err != nil {
return errors.Errorf("error forwarding stdin: %w", err)
}

return nil
})

// Wait for stdin copy to complete before returning
if copyStdinErr := <-stdinDone; copyStdinErr != nil && !errors.IsError(copyStdinErr, io.EOF) {
logger.Errorf("Error forwarding stdin: %s", copyStdinErr)
if err := errGroup.Wait(); err != nil && !errors.IsError(err, io.EOF) && !errors.IsContextCanceled(err) {
return errors.New(err)
}

return nil
Expand Down
2 changes: 1 addition & 1 deletion shell/run_shell_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ func RunShellCommandWithOutput(
exec.WithForwardSignalDelay(SignalForwardingDelay),
)

if err := cmd.Start(); err != nil {
if err := cmd.Start(); err != nil { //nolint:contextcheck
err = util.ProcessExecutionError{
Err: err,
Args: args,
Expand Down
4 changes: 2 additions & 2 deletions terraform/cache/helpers/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"strconv"

"github.com/gruntwork-io/terragrunt/internal/errors"
"github.com/hashicorp/go-getter/v2"
"github.com/gruntwork-io/terragrunt/util"
)

func Fetch(ctx context.Context, req *http.Request, dst io.Writer) error {
Expand All @@ -34,7 +34,7 @@ func Fetch(ctx context.Context, req *http.Request, dst io.Writer) error {
return err
}

if written, err := getter.Copy(ctx, dst, reader); err != nil {
if written, err := util.Copy(ctx, dst, reader); err != nil {
return errors.New(err)
} else if resp.ContentLength != -1 && written != resp.ContentLength {
return errors.Errorf("incorrect response size: expected %d bytes, but got %d bytes", resp.ContentLength, written)
Expand Down
43 changes: 43 additions & 0 deletions util/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package util

import (
"bytes"
"context"
"crypto/sha256"
"encoding/gob"
"io"
Expand Down Expand Up @@ -786,3 +787,45 @@ func FileSHA256(filePath string) ([]byte, error) {

return hash.Sum(nil), nil
}

// readerFunc is syntactic sugar for read interface.
type readerFunc func(data []byte) (int, error)

func (rf readerFunc) Read(data []byte) (int, error) { return rf(data) }

// writerFunc is syntactic sugar for write interface.
type writerFunc func(data []byte) (int, error)

func (wf writerFunc) Write(data []byte) (int, error) { return wf(data) }

// Copy is a io.Copy cancellable by context.
func Copy(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
num, err := io.Copy(
writerFunc(func(data []byte) (int, error) {
select {
case <-ctx.Done():
// context has been canceled stop process and propagate "context canceled" error.
return 0, ctx.Err()
default:
// otherwise just run default io.Writer implementation.
return dst.Write(data)
}
}),
readerFunc(func(data []byte) (int, error) {
select {
case <-ctx.Done():
// context has been canceled stop process and propagate "context canceled" error.
return 0, ctx.Err()
default:
// otherwise just run default io.Reader implementation.
return src.Read(data)
}
}),
)

if err != nil {
err = errors.New(err)
}

return num, err
}

0 comments on commit 45f3160

Please sign in to comment.