Skip to content

Commit

Permalink
Merge pull request #1165 from 99designs/env-var-to-disable-help-message
Browse files Browse the repository at this point in the history
Env var to disable help message
  • Loading branch information
mtibben authored Feb 26, 2023
2 parents 97de797 + 62984c3 commit ea3c3d6
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 98 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ jonsmith jonsmith -

# Start a subshell with temporary credentials
$ aws-vault exec jonsmith
aws-vault: Starting a subshell /bin/zsh, use `exit` to exit the subshell
Starting subshell /bin/zsh, use `exit` to exit the subshell
$ aws s3 ls
bucket_1
bucket_2
Expand Down
154 changes: 75 additions & 79 deletions cli/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,18 @@ import (
)

type ExecCommandInput struct {
ProfileName string
Command string
Args []string
StartEc2Server bool
StartEcsServer bool
Lazy bool
JSONDeprecated bool
Config vault.Config
SessionDuration time.Duration
NoSession bool
UseStdout bool
ProfileName string
Command string
Args []string
StartEc2Server bool
StartEcsServer bool
Lazy bool
JSONDeprecated bool
Config vault.Config
SessionDuration time.Duration
NoSession bool
UseStdout bool
ShowHelpMessages bool
}

func (input ExecCommandInput) validate() error {
Expand Down Expand Up @@ -122,6 +123,7 @@ func ConfigureExecCommand(app *kingpin.Application, a *AwsVault) {
input.Config.NonChainedGetSessionTokenDuration = input.SessionDuration
input.Config.AssumeRoleDuration = input.SessionDuration
input.Config.SSOUseStdout = input.UseStdout
input.ShowHelpMessages = input.Command == "" && isATerminal() && os.Getenv("AWS_VAULT_DISABLE_HELP_MESSAGE") != "1"

f, err := a.AwsConfigFile()
if err != nil {
Expand Down Expand Up @@ -153,44 +155,75 @@ func ConfigureExecCommand(app *kingpin.Application, a *AwsVault) {

func ExecCommand(input ExecCommandInput, f *vault.ConfigFile, keyring keyring.Keyring) error {
if os.Getenv("AWS_VAULT") != "" {
return fmt.Errorf("in an existing aws-vault subshell; 'exit' from the subshell or unset AWS_VAULT to force")
return fmt.Errorf("running in an existing aws-vault subshell; 'exit' from the subshell or unset AWS_VAULT to force")
}

err := input.validate()
if err != nil {
if err := input.validate(); err != nil {
return err
}

vault.UseSession = !input.NoSession

configLoader := vault.ConfigLoader{
File: f,
BaseConfig: input.Config,
ActiveProfile: input.ProfileName,
}
config, err := configLoader.LoadFromProfile(input.ProfileName)
config, err := vault.NewConfigLoader(input.Config, f, input.ProfileName).LoadFromProfile(input.ProfileName)
if err != nil {
return fmt.Errorf("Error loading config: %w", err)
}

ckr := &vault.CredentialKeyring{Keyring: keyring}
credsProvider, err := vault.NewTempCredentialsProvider(config, ckr)
credsProvider, err := vault.NewTempCredentialsProvider(config, &vault.CredentialKeyring{Keyring: keyring})
if err != nil {
return fmt.Errorf("Error getting temporary credentials: %w", err)
}

subshellHelp := ""
if input.Command == "" {
input.Command = getDefaultShell()
subshellHelp = fmt.Sprintf("Starting subshell %s, use `exit` to exit the subshell", input.Command)
}

cmdEnv := createEnv(input.ProfileName, config.Region)

if input.StartEc2Server {
return execEc2Server(input, config, credsProvider)
printHelpMessage("Starting an EC2 credential server on 169.254.169.254:80", input.ShowHelpMessages)
if err = server.StartEc2CredentialsServer(context.TODO(), credsProvider, config.Region); err != nil {
return fmt.Errorf("Failed to start credential server: %w", err)
}
printHelpMessage(subshellHelp, input.ShowHelpMessages)
} else if input.StartEcsServer {
printHelpMessage("Starting an ECS credential server; your app's AWS sdk must support AWS_CONTAINER_CREDENTIALS_FULL_URI.", input.ShowHelpMessages)
if err = startEcsServerAndSetEnv(credsProvider, config, input.Lazy, &cmdEnv); err != nil {
return err
}
printHelpMessage(subshellHelp, input.ShowHelpMessages)
} else {
if err = addCredsToEnv(credsProvider, input.ProfileName, &cmdEnv); err != nil {
return err
}
printHelpMessage(subshellHelp, input.ShowHelpMessages)

if osSupportsExecSyscall() {
return doExecSyscall(input.Command, input.Args, cmdEnv)
}
}

if input.StartEcsServer {
return execEcsServer(input, config, credsProvider)
return runChildProcess(input.Command, input.Args, cmdEnv)
}

func printHelpMessage(helpMsg string, showHelpMessages bool) {
if helpMsg != "" {
if showHelpMessages {
printToStderr(helpMsg)
} else {
log.Println(helpMsg)
}
}
}

return execEnvironment(input, config, credsProvider)
func printToStderr(helpMsg string) {
fmt.Fprint(os.Stderr, helpMsg, "\n")
}

func updateEnvForAwsVault(env environ, profileName string, region string) environ {
func createEnv(profileName string, region string) environ {
env := environ(os.Environ())
env.Unset("AWS_ACCESS_KEY_ID")
env.Unset("AWS_SECRET_ACCESS_KEY")
env.Unset("AWS_SESSION_TOKEN")
Expand All @@ -213,20 +246,8 @@ func updateEnvForAwsVault(env environ, profileName string, region string) enviro
return env
}

func execEc2Server(input ExecCommandInput, config *vault.Config, credsProvider aws.CredentialsProvider) error {
fmt.Fprintf(os.Stderr, "aws-vault: Starting an EC2 credential server.\n")
if err := server.StartEc2CredentialsServer(context.TODO(), credsProvider, config.Region); err != nil {
return fmt.Errorf("Failed to start credential server: %w", err)
}

env := environ(os.Environ())
env = updateEnvForAwsVault(env, input.ProfileName, config.Region)

return doRunCmd(input.Command, input.Args, env)
}

func execEcsServer(input ExecCommandInput, config *vault.Config, credsProvider aws.CredentialsProvider) error {
ecsServer, err := server.NewEcsServer(context.TODO(), credsProvider, config, "", 0, input.Lazy)
func startEcsServerAndSetEnv(credsProvider aws.CredentialsProvider, config *vault.Config, lazy bool, cmdEnv *environ) error {
ecsServer, err := server.NewEcsServer(context.TODO(), credsProvider, config, "", 0, lazy)
if err != nil {
return err
}
Expand All @@ -238,48 +259,32 @@ func execEcsServer(input ExecCommandInput, config *vault.Config, credsProvider a
}()

log.Println("Setting subprocess env AWS_CONTAINER_CREDENTIALS_FULL_URI, AWS_CONTAINER_AUTHORIZATION_TOKEN")
env := environ(os.Environ())
env = updateEnvForAwsVault(env, input.ProfileName, config.Region)
env.Set("AWS_CONTAINER_CREDENTIALS_FULL_URI", ecsServer.BaseURL())
env.Set("AWS_CONTAINER_AUTHORIZATION_TOKEN", ecsServer.AuthToken())
cmdEnv.Set("AWS_CONTAINER_CREDENTIALS_FULL_URI", ecsServer.BaseURL())
cmdEnv.Set("AWS_CONTAINER_AUTHORIZATION_TOKEN", ecsServer.AuthToken())

helpMsg := "Started an ECS credential server; your app's AWS sdk must support AWS_CONTAINER_CREDENTIALS_FULL_URI."
if input.Command == "" {
fmt.Fprintf(os.Stderr, "aws-vault: %s\n", helpMsg)
} else {
log.Println(helpMsg)
}

return doRunCmd(input.Command, input.Args, env)
return nil
}

func execEnvironment(input ExecCommandInput, config *vault.Config, credsProvider aws.CredentialsProvider) error {
func addCredsToEnv(credsProvider aws.CredentialsProvider, profileName string, cmdEnv *environ) error {
creds, err := credsProvider.Retrieve(context.TODO())
if err != nil {
return fmt.Errorf("Failed to get credentials for %s: %w", input.ProfileName, err)
return fmt.Errorf("Failed to get credentials for %s: %w", profileName, err)
}

env := environ(os.Environ())
env = updateEnvForAwsVault(env, input.ProfileName, config.Region)

log.Println("Setting subprocess env: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY")
env.Set("AWS_ACCESS_KEY_ID", creds.AccessKeyID)
env.Set("AWS_SECRET_ACCESS_KEY", creds.SecretAccessKey)
cmdEnv.Set("AWS_ACCESS_KEY_ID", creds.AccessKeyID)
cmdEnv.Set("AWS_SECRET_ACCESS_KEY", creds.SecretAccessKey)

if creds.SessionToken != "" {
log.Println("Setting subprocess env: AWS_SESSION_TOKEN")
env.Set("AWS_SESSION_TOKEN", creds.SessionToken)
cmdEnv.Set("AWS_SESSION_TOKEN", creds.SessionToken)
}
if creds.CanExpire {
log.Println("Setting subprocess env: AWS_CREDENTIAL_EXPIRATION")
env.Set("AWS_CREDENTIAL_EXPIRATION", iso8601.Format(creds.Expires))
}

if !supportsExecSyscall() {
return doRunCmd(input.Command, input.Args, env)
cmdEnv.Set("AWS_CREDENTIAL_EXPIRATION", iso8601.Format(creds.Expires))
}

return doExecSyscall(input.Command, input.Args, env)
return nil
}

// environ is a slice of strings representing the environment, in the form "key=value".
Expand Down Expand Up @@ -314,12 +319,7 @@ func getDefaultShell() string {
return command
}

func doRunCmd(command string, args []string, env []string) error {
if command == "" {
command = getDefaultShell()
fmt.Fprintf(os.Stderr, "aws-vault: Starting a subshell %s, use `exit` to exit the subshell\n", command)
}

func runChildProcess(command string, args []string, env []string) error {
log.Printf("Starting subprocess: %s %s", command, strings.Join(args, " "))

cmd := osexec.Command(command, args...)
Expand All @@ -335,6 +335,7 @@ func doRunCmd(command string, args []string, env []string) error {
return err
}

// proxy signals to process
go func() {
for {
sig := <-sigChan
Expand All @@ -352,16 +353,11 @@ func doRunCmd(command string, args []string, env []string) error {
return nil
}

func supportsExecSyscall() bool {
func osSupportsExecSyscall() bool {
return runtime.GOOS == "linux" || runtime.GOOS == "darwin" || runtime.GOOS == "freebsd" || runtime.GOOS == "openbsd"
}

func doExecSyscall(command string, args []string, env []string) error {
if command == "" {
command = getDefaultShell()
fmt.Fprintf(os.Stderr, "aws-vault: Starting a subshell %s\n", command)
}

log.Printf("Exec command %s %s", command, strings.Join(args, " "))

argv0, err := osexec.LookPath(command)
Expand Down
7 changes: 1 addition & 6 deletions cli/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,7 @@ func ExportCommand(input ExportCommandInput, f *vault.ConfigFile, keyring keyrin

vault.UseSession = !input.NoSession

configLoader := vault.ConfigLoader{
File: f,
BaseConfig: input.Config,
ActiveProfile: input.ProfileName,
}
config, err := configLoader.LoadFromProfile(input.ProfileName)
config, err := vault.NewConfigLoader(input.Config, f, input.ProfileName).LoadFromProfile(input.ProfileName)
if err != nil {
return fmt.Errorf("Error loading config: %w", err)
}
Expand Down
3 changes: 2 additions & 1 deletion cli/global.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ type AwsVault struct {
}

func isATerminal() bool {
return isatty.IsTerminal(os.Stdout.Fd()) || isatty.IsCygwinTerminal(os.Stdout.Fd())
fd := os.Stdout.Fd()
return isatty.IsTerminal(fd) || isatty.IsCygwinTerminal(fd)
}

func (a *AwsVault) PromptDriver(avoidTerminalPrompt bool) string {
Expand Down
7 changes: 1 addition & 6 deletions cli/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,7 @@ func ConfigureLoginCommand(app *kingpin.Application, a *AwsVault) {
func LoginCommand(input LoginCommandInput, f *vault.ConfigFile, keyring keyring.Keyring) error {
vault.UseSession = !input.NoSession

configLoader := vault.ConfigLoader{
File: f,
BaseConfig: input.Config,
ActiveProfile: input.ProfileName,
}
config, err := configLoader.LoadFromProfile(input.ProfileName)
config, err := vault.NewConfigLoader(input.Config, f, input.ProfileName).LoadFromProfile(input.ProfileName)
if err != nil {
return fmt.Errorf("Error loading config: %w", err)
}
Expand Down
6 changes: 1 addition & 5 deletions cli/rotate.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,7 @@ func RotateCommand(input RotateCommandInput, f *vault.ConfigFile, keyring keyrin
vault.UseSession = !input.NoSession
vault.UseSessionCache = false

configLoader := &vault.ConfigLoader{
File: f,
BaseConfig: input.Config,
ActiveProfile: input.ProfileName,
}
configLoader := vault.NewConfigLoader(input.Config, f, input.ProfileName)
config, err := configLoader.LoadFromProfile(input.ProfileName)
if err != nil {
return fmt.Errorf("Error loading config: %w", err)
Expand Down
8 changes: 8 additions & 0 deletions vault/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,14 @@ type ConfigLoader struct {
visitedProfiles []string
}

func NewConfigLoader(baseConfig Config, file *ConfigFile, activeProfile string) *ConfigLoader {
return &ConfigLoader{
BaseConfig: baseConfig,
File: file,
ActiveProfile: activeProfile,
}
}

func (cl *ConfigLoader) visitProfile(name string) bool {
for _, p := range cl.visitedProfiles {
if p == name {
Expand Down

0 comments on commit ea3c3d6

Please sign in to comment.