diff --git a/pkg/ssh/ssh_test.go b/pkg/ssh/ssh_test.go index dc8608e60..b7b6eddbe 100644 --- a/pkg/ssh/ssh_test.go +++ b/pkg/ssh/ssh_test.go @@ -61,6 +61,19 @@ func TestDial(t *testing.T) { _, err = Dial(&options, GolangMode) require.Error(t, err, "failed to connect: ssh: handshake failed: ssh: disconnect, reason 2: Too many authentication failures") + + // Test again without specifying sshd port, and code should default to port 22 + options = ConnectionDialOptions{ + Host: "localhost", + } + + _, err = Dial(&options, NativeMode) + // exit status 255 is what you get when ssh is not enabled or the connection failed + // this means up to that point, everything worked + require.Error(t, err, "exit status 255") + + _, err = Dial(&options, GolangMode) + require.Error(t, err, "failed to connect: ssh: handshake failed: ssh: disconnect, reason 2: Too many authentication failures") } func TestScp(t *testing.T) { diff --git a/pkg/ssh/utils.go b/pkg/ssh/utils.go index a19468d3a..051e2f758 100644 --- a/pkg/ssh/utils.go +++ b/pkg/ssh/utils.go @@ -15,6 +15,8 @@ import ( "golang.org/x/term" ) +const sshdPort = 22 + func Validate(user *url.Userinfo, path string, port int, identity string) (*config.Destination, *url.URL, error) { // url.Parse NEEDS ssh://, if this ever fails or returns some nonsense, that is why. uri, err := url.Parse(path) @@ -28,11 +30,10 @@ func Validate(user *url.Userinfo, path string, port int, identity string) (*conf } if uri.Port() == "" { - if port != 0 { - uri.Host = net.JoinHostPort(uri.Host, strconv.Itoa(port)) - } else { - uri.Host = net.JoinHostPort(uri.Host, "22") + if port == 0 { + port = sshdPort } + uri.Host = net.JoinHostPort(uri.Host, strconv.Itoa(port)) } if user != nil { @@ -165,11 +166,15 @@ func ParseScpArgs(options ConnectionScpOptions) (string, string, string, bool, e } func DialNet(sshClient *ssh.Client, mode string, url *url.URL) (net.Conn, error) { - port, err := strconv.Atoi(url.Port()) - if err != nil { - return nil, err + port := sshdPort + if url.Port() != "" { + p, err := strconv.Atoi(url.Port()) + if err != nil { + return nil, err + } + port = p } - if _, _, err = Validate(url.User, url.Hostname(), port, ""); err != nil { + if _, _, err := Validate(url.User, url.Hostname(), port, ""); err != nil { return nil, err } return sshClient.Dial(mode, url.Path)