Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow vault ssh to accept ssh commands in any ssh compatible format #4710

Merged
merged 2 commits into from
Jun 14, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 125 additions & 51 deletions command/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,17 +243,27 @@ func (c *SSHCommand) Run(args []string) int {
return 1
}

// Extract the username and IP.
username, hostname, ip, err := c.userHostAndIP(args[0])
// Extract the hostname, username and port from the ssh command
hostname, username, port, err := c.parseSSHCommand(args)
if err != nil {
c.UI.Error(fmt.Sprintf("Error parsing user and IP: %s", err))
c.UI.Error(fmt.Sprintf("Error parsing the ssh command: %q", err))
return 1
}

// The rest of the args are ssh args
sshArgs := []string{}
if len(args) > 1 {
sshArgs = args[1:]
// Use the current user if no user was specified in the ssh command
if username == "" {
u, err := user.Current()
if err != nil {
c.UI.Error(fmt.Sprintf("Error getting the current user: %q", err))
return 1
}
username = u.Username
}

ip, err := c.resolveHostname(hostname)
if err != nil {
c.UI.Error(fmt.Sprintf("Error resolving the ssh hostname: %q", err))
return 1
}

// Set the client in the command
Expand Down Expand Up @@ -329,19 +339,19 @@ func (c *SSHCommand) Run(args []string) int {

switch strings.ToLower(c.flagMode) {
case ssh.KeyTypeCA:
return c.handleTypeCA(username, hostname, ip, sshArgs)
return c.handleTypeCA(username, ip, port, args)
case ssh.KeyTypeOTP:
return c.handleTypeOTP(username, hostname, ip, sshArgs)
return c.handleTypeOTP(username, ip, port, args)
case ssh.KeyTypeDynamic:
return c.handleTypeDynamic(username, ip, sshArgs)
return c.handleTypeDynamic(username, ip, port, args)
default:
c.UI.Error(fmt.Sprintf("Unknown SSH mode: %s", c.flagMode))
return 1
}
}

// handleTypeCA is used to handle SSH logins using the "CA" key type.
func (c *SSHCommand) handleTypeCA(username, hostname, ip string, sshArgs []string) int {
func (c *SSHCommand) handleTypeCA(username, ip, port string, sshArgs []string) int {
// Read the key from disk
publicKey, err := ioutil.ReadFile(c.flagPublicKeyPath)
if err != nil {
Expand Down Expand Up @@ -460,10 +470,6 @@ func (c *SSHCommand) handleTypeCA(username, hostname, ip string, sshArgs []strin
)
}

args = append(args,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this removed entirely? We don't need to pass in the username, IP, and port to the SSH command explicitly any more?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. My comment here should clear things up for you: #4710 (comment)

The current behaviour was to expect username@hostname as the first argument. This was then added to the actual ssh command with username@hostname plus the remaining ssh arguments. This PR removes that limitation of needing to use username@hostname and instead allows you to use any valid ssh command so that vault doesn't need to modify the actual command.

username+"@"+hostname,
)

// Add extra user defined ssh arguments
args = append(args, sshArgs...)

Expand Down Expand Up @@ -493,7 +499,7 @@ func (c *SSHCommand) handleTypeCA(username, hostname, ip string, sshArgs []strin
}

// handleTypeOTP is used to handle SSH logins using the "otp" key type.
func (c *SSHCommand) handleTypeOTP(username, hostname string, ip string, sshArgs []string) int {
func (c *SSHCommand) handleTypeOTP(username, ip, port string, sshArgs []string) int {
secret, cred, err := c.generateCredential(username, ip)
if err != nil {
c.UI.Error(fmt.Sprintf("failed to generate credential: %s", err))
Expand Down Expand Up @@ -543,10 +549,13 @@ func (c *SSHCommand) handleTypeOTP(username, hostname string, ip string, sshArgs
)
}

// If a port wasn't specified in the ssh arguments lets use the port we got back from vault
if port == "" {
args = append(args, "-p", cred.Port)
}

args = append(args,
"-o StrictHostKeyChecking="+c.flagStrictHostKeyChecking,
"-p", cred.Port,
username+"@"+hostname,
)

// Add the rest of the ssh args appended by the user
Expand Down Expand Up @@ -585,7 +594,7 @@ func (c *SSHCommand) handleTypeOTP(username, hostname string, ip string, sshArgs
}

// handleTypeDynamic is used to handle SSH logins using the "dyanmic" key type.
func (c *SSHCommand) handleTypeDynamic(username, ip string, sshArgs []string) int {
func (c *SSHCommand) handleTypeDynamic(username, ip, port string, sshArgs []string) int {
// Generate the credential
secret, cred, err := c.generateCredential(username, ip)
if err != nil {
Expand All @@ -610,13 +619,20 @@ func (c *SSHCommand) handleTypeDynamic(username, ip string, sshArgs []string) in
return 1
}

args := append([]string{
args := make([]string, 0)
// If a port wasn't specified in the ssh arguments lets use the port we got back from vault
if port == "" {
args = append(args, "-p", cred.Port)
}

args = append(args,
"-i", keyPath,
"-o UserKnownHostsFile=" + c.flagUserKnownHostsFile,
"-o StrictHostKeyChecking=" + c.flagStrictHostKeyChecking,
"-p", cred.Port,
username + "@" + ip,
}, sshArgs...)
"-o UserKnownHostsFile="+c.flagUserKnownHostsFile,
"-o StrictHostKeyChecking="+c.flagStrictHostKeyChecking,
)

// Add extra user defined ssh arguments
args = append(args, sshArgs...)

cmd := exec.Command("ssh", args...)
cmd.Stdin = os.Stdin
Expand Down Expand Up @@ -745,37 +761,95 @@ func (c *SSHCommand) defaultRole(mountPoint, ip string) (string, error) {
}
}

// userAndIP takes an argument in the format [email protected] and separates the IP
// and user parts, returning any errors.
func (c *SSHCommand) userHostAndIP(s string) (string, string, string, error) {
// split the parameter username@ip
input := strings.Split(s, "@")
var username, address string

// If only IP is mentioned and username is skipped, assume username to
// be the current username. Vault SSH role's default username could have
// been used, but in order to retain the consistency with SSH command,
// current username is employed.
switch len(input) {
case 1:
u, err := user.Current()
if err != nil {
return "", "", "", errors.Wrap(err, "failed to fetch current user")
// Finds the hostname, username (optional) and port (optional) from any valid ssh command
// Supports usrname@hostname but also specifying valid ssh flags like -o User=username,
// -o Port=2222 and -p 2222 anywhere in the command
func (c *SSHCommand) parseSSHCommand(args []string) (hostname string, username string, port string, err error) {
lastArg := ""

for _, i := range args {
arg := lastArg
lastArg = ""

// If -p has been specified then this is our ssh port
if arg == "-p" {
port = i
continue
}
username, address = u.Username, input[0]
case 2:
username, address = input[0], input[1]
default:
return "", "", "", fmt.Errorf("invalid arguments: %q", s)

// this is an ssh option, lets see if User or Port have been set and use it
if arg == "-o" {
split := strings.Split(i, "=")
key := split[0]
// Incase the value contains = signs we want to get all of them
value := strings.Join(split[1:], " ")

if key == "User" {
// Don't overwrite the user if it is already set by username@hostname
// This matches the behaviour for how regular ssh reponds when both are specified
if username == "" {
username = value
}
}

if key == "Port" {
// Don't overwrite the port if it is already set by -p
// This matches the behaviour for how regular ssh reponds when both are specified
if port == "" {
port = value
}
}
continue
}

// This isn't an ssh argument that we care about. Lets keep on parsing the command
if arg != "" {
continue
}

// If this is an ssh argument we want to look at the value
if strings.HasPrefix(i, "-") {
lastArg = i
continue
}

// If we have gotten this far it means this is a bare argument
// The first bare argument is the hostname
// The second bare argument is the command to run on the remote host

// If the hostname hasn't been set yet than it means we have found the first bare argument
if hostname == "" {
if strings.Contains(i, "@") {
split := strings.Split(i, "@")
username = split[0]
hostname = split[1]
} else {
hostname = i
}
continue
} else {
// The second bare argument is the command to run on the remote host.
// We need to break out and stop parsing arugments now
break
}

}
if hostname == "" {
return "", "", "", errors.Wrap(
err,
fmt.Sprintf("failed to find a hostname in ssh command %q", strings.Join(args, " ")),
)
}
return hostname, username, port, nil
}

func (c *SSHCommand) resolveHostname(hostname string) (ip string, err error) {
// Resolving domain names to IP address on the client side.
// Vault only deals with IP addresses.
ipAddr, err := net.ResolveIPAddr("ip", address)
ipAddr, err := net.ResolveIPAddr("ip", hostname)
if err != nil {
return "", "", "", errors.Wrap(err, "failed to resolve IP address")
return "", errors.Wrap(err, "failed to resolve IP address")
}
ip := ipAddr.String()

return username, address, ip, nil
ip = ipAddr.String()
return ip, nil
}
133 changes: 133 additions & 0 deletions command/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,136 @@ func TestSSHCommand_Run(t *testing.T) {
t.Parallel()
t.Skip("Need a way to setup target infrastructure")
}

func TestParseSSHCommand(t *testing.T) {
t.Parallel()

_, cmd := testSSHCommand(t)
var tests = []struct {
name string
args []string
hostname string
username string
port string
err error
}{
{
"Parse just a hostname",
[]string{
"hostname",
},
"hostname",
"",
"",
nil,
},
{
"Parse the standard username@hostname",
[]string{
"username@hostname",
},
"hostname",
"username",
"",
nil,
},
{
"Parse the username out of -o User=username",
[]string{
"-o", "User=username",
"hostname",
},
"hostname",
"username",
"",
nil,
},
{
"If the username is specified with -o User=username and realname@hostname prefer realname@",
[]string{
"-o", "User=username",
"realname@hostname",
},
"hostname",
"realname",
"",
nil,
},
{
"Parse the port out of -o Port=2222",
[]string{
"-o", "Port=2222",
"hostname",
},
"hostname",
"",
"2222",
nil,
},
{
"Parse the port out of -p 2222",
[]string{
"-p", "2222",
"hostname",
},
"hostname",
"",
"2222",
nil,
},
{
"If port is defined with -o Port=2222 and -p 2244 prefer -p",
[]string{
"-p", "2244",
"-o", "Port=2222",
"hostname",
},
"hostname",
"",
"2244",
nil,
},
{
"Ssh args with a command",
[]string{
"hostname",
"command",
},
"hostname",
"",
"",
nil,
},
{
"Flags after the ssh command are not pased because they are part of the command",
[]string{
"username@hostname",
"command",
"-p 22",
},
"hostname",
"username",
"",
nil,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {

hostname, username, port, err := cmd.parseSSHCommand(test.args)
if err != test.err {
t.Errorf("got error: %q want %q", err, test.err)
}
if hostname != test.hostname {
t.Errorf("got hostname: %q want %q", hostname, test.hostname)
}
if username != test.username {
t.Errorf("got username: %q want %q", username, test.username)
}
if port != test.port {
t.Errorf("got port: %q want %q", port, test.port)
}
})
}
}