Skip to content

Commit

Permalink
Create a proper shell through exec (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
Erik Jansson authored Feb 14, 2022
1 parent a9d295f commit f415d20
Showing 1 changed file with 52 additions and 8 deletions.
60 changes: 52 additions & 8 deletions pkg/shell/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"os"
"os/exec"
"os/signal"
"syscall"
"time"

Expand All @@ -19,20 +20,63 @@ func ShellWithCredentials(profile string, creds aws.Credentials) error {
return err
}

os.Setenv("AWS_ACCESS_KEY_ID", creds.AccessKeyID)
os.Setenv("AWS_SECRET_ACCESS_KEY", creds.SecretAccessKey)
os.Setenv("AWS_SESSION_TOKEN", creds.SessionToken)
os.Setenv("AWS_SECURITY_TOKEN", creds.SessionToken)
os.Setenv("ASSUMED_PROFILE", profile)
env := os.Environ()

env = append(env, []string{
fmt.Sprintf("AWS_ACCESS_KEY_ID=%s", creds.AccessKeyID),
fmt.Sprintf("AWS_SECRET_ACCESS_KEY=%s", creds.SecretAccessKey),
fmt.Sprintf("AWS_SESSION_TOKEN=%s", creds.SessionToken),
fmt.Sprintf("AWS_SECURITY_TOKEN=%s", creds.SessionToken),
fmt.Sprintf("ASSUMED_PROFILE=%s", profile),
}...)

var expires string
if creds.CanExpire {
os.Setenv("ASSUMED_PROFILE_EXPIRES", fmt.Sprintf("%d", creds.Expires.Unix()))
env = append(env, fmt.Sprintf("ASSUMED_PROFILE_EXPIRES=%d", creds.Expires.Unix()))
expires = fmt.Sprintf("expires in %s", durafmt.Parse(time.Until(creds.Expires)).LimitFirstN(2))
} else {
expires = "never expires"
}
env := os.Environ()

fmt.Printf("Exported assumed role credentials for profile %s, %s\n", profile, expires)
return syscall.Exec(argv0, []string{}, env)

cmd := exec.Command(argv0)
cmd.Env = env
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr

inputSignal := make(chan os.Signal, 1)
signal.Notify(inputSignal, os.Interrupt, syscall.SIGTERM)

if err := cmd.Start(); err != nil {
return err
}

waitChannel := make(chan error, 1)
go func() {
waitChannel <- cmd.Wait()
close(waitChannel)
}()

for {

select {
case sig := <-inputSignal:
if err := cmd.Process.Signal(sig); err != nil {
return err
}
case err := <-waitChannel:
var waitStatus syscall.WaitStatus
if exitError, ok := err.(*exec.ExitError); ok {
waitStatus = exitError.Sys().(syscall.WaitStatus)
os.Exit(waitStatus.ExitStatus())
}
if err != nil {
return err
}
return nil
}
}

}

0 comments on commit f415d20

Please sign in to comment.