diff --git a/modules/ssh/ssh.go b/modules/ssh/ssh.go index e6b220b9e..147363550 100644 --- a/modules/ssh/ssh.go +++ b/modules/ssh/ssh.go @@ -261,6 +261,32 @@ func CheckSshCommandE(t testing.TestingT, host Host, command string) (string, er return runSSHCommand(t, sshSession) } +// CheckSshCommandWithRetry checks that you can connect via SSH to the given host and run the given command until max retries have been exceeded. Returns the stdout/stderr. +func CheckSshCommandWithRetry(t testing.TestingT, host Host, command string, retries int, sleepBetweenRetries time.Duration, f ...func(testing.TestingT, Host, string) (string, error)) string { + handler := CheckSshCommandE + if f != nil { + handler = f[0] + } + out, err := CheckSshCommandWithRetryE(t, host, command, retries, sleepBetweenRetries, handler) + if err != nil { + t.Fatal(err) + } + return out +} + +// CheckSshCommandWithRetryE checks that you can connect via SSH to the given host and run the given command until max retries has been exceeded. +// It return an error if the command fails after max retries has been exceeded. + +func CheckSshCommandWithRetryE(t testing.TestingT, host Host, command string, retries int, sleepBetweenRetries time.Duration, f ...func(testing.TestingT, Host, string) (string, error)) (string, error) { + handler := CheckSshCommandE + if f != nil { + handler = f[0] + } + return retry.DoWithRetryE(t, fmt.Sprintf("Checking SSH connection to %s", host.Hostname), retries, sleepBetweenRetries, func() (string, error) { + return handler(t, host, command) + }) +} + // CheckPrivateSshConnection attempts to connect to privateHost (which is not addressable from the Internet) via a // separate publicHost (which is addressable from the Internet) and then executes "command" on privateHost and returns // its output. It is useful for checking that it's possible to SSH from a Bastion Host to a private instance. diff --git a/modules/ssh/ssh_test.go b/modules/ssh/ssh_test.go index 6f2ddd3c5..bdf41ca03 100644 --- a/modules/ssh/ssh_test.go +++ b/modules/ssh/ssh_test.go @@ -26,26 +26,87 @@ func TestHostWithCustomPort(t *testing.T) { assert.Equal(t, customPort, host.getPort(), "host.getPort() did not return the custom port number") } -func TestCheckSSHConnectionWithRetryE(t *testing.T) { +// global var for use in mock callback +var timesCalled int + +func TestCheckSshConnectionWithRetryE(t *testing.T) { + // Reset the global call count timesCalled = 0 + host := Host{Hostname: "Host"} - assert.Nil(t, CheckSshConnectionWithRetryE(t, host, 10, 3, mockSshConnectionE)) + retries := 10 + + assert.Nil(t, CheckSshConnectionWithRetryE(t, host, retries, 3, mockSshConnectionE)) +} + +func TestCheckSshConnectionWithRetryEExceedsMaxRetries(t *testing.T) { + // Reset the global call count + timesCalled = 0 + + host := Host{Hostname: "Host"} + + // Not enough retries + retries := 3 + + assert.Error(t, CheckSshConnectionWithRetryE(t, host, retries, 3, mockSshConnectionE)) } func TestCheckSshConnectionWithRetry(t *testing.T) { + // Reset the global call count timesCalled = 0 + host := Host{Hostname: "Host"} - CheckSshConnectionWithRetry(t, host, 10, 3, mockSshConnectionE) + retries := 10 + + CheckSshConnectionWithRetry(t, host, retries, 3, mockSshConnectionE) } -var timesCalled int +func TestCheckSshCommandWithRetryE(t *testing.T) { + // Reset the global call count + timesCalled = 0 + + host := Host{Hostname: "Host"} + command := "echo -n hello world" + retries := 10 + + _, err := CheckSshCommandWithRetryE(t, host, command, retries, 3, mockSshCommandE) + assert.Nil(t, err) +} + +func TestCheckSshCommandWithRetryEExceedsRetries(t *testing.T) { + // Reset the global call count + timesCalled = 0 + + host := Host{Hostname: "Host"} + command := "echo -n hello world" + + // Not enough retries + retries := 3 + + _, err := CheckSshCommandWithRetryE(t, host, command, retries, 3, mockSshCommandE) + assert.Error(t, err) +} + +func TestCheckSshCommandWithRetry(t *testing.T) { + // Reset the global call count + timesCalled = 0 + + host := Host{Hostname: "Host"} + command := "echo -n hello world" + retries := 10 + + CheckSshCommandWithRetry(t, host, command, retries, 3, mockSshCommandE) +} func mockSshConnectionE(t grunttest.TestingT, host Host) error { timesCalled += 1 - fmt.Println() if timesCalled >= 5 { return nil } else { return errors.New(fmt.Sprintf("Called %v times", timesCalled)) } } + +func mockSshCommandE(t grunttest.TestingT, host Host, command string) (string, error) { + return "", mockSshConnectionE(t, host) +}