diff --git a/modules/aws/ssm.go b/modules/aws/ssm.go index 1a66fb3ab..364841868 100644 --- a/modules/aws/ssm.go +++ b/modules/aws/ssm.go @@ -26,7 +26,10 @@ func GetParameterE(t testing.TestingT, awsRegion string, keyName string) (string return "", err } - resp, err := ssmClient.GetParameter(&ssm.GetParameterInput{Name: aws.String(keyName), WithDecryption: aws.Bool(true)}) + return GetParameterWithClientE(t, ssmClient, keyName) +} +func GetParameterWithClientE(t testing.TestingT, client *ssm.SSM, keyName string) (string, error) { + resp, err := client.GetParameter(&ssm.GetParameterInput{Name: aws.String(keyName), WithDecryption: aws.Bool(true)}) if err != nil { return "", err } @@ -48,8 +51,10 @@ func PutParameterE(t testing.TestingT, awsRegion string, keyName string, keyDesc if err != nil { return 0, err } - - resp, err := ssmClient.PutParameter(&ssm.PutParameterInput{Name: aws.String(keyName), Description: aws.String(keyDescription), Value: aws.String(keyValue), Type: aws.String("SecureString")}) + return PutParameterWithClientE(t, ssmClient, keyName, keyDescription, keyValue) +} +func PutParameterWithClientE(t testing.TestingT, client *ssm.SSM, keyName string, keyDescription string, keyValue string) (int64, error) { + resp, err := client.PutParameter(&ssm.PutParameterInput{Name: aws.String(keyName), Description: aws.String(keyDescription), Value: aws.String(keyValue), Type: aws.String("SecureString")}) if err != nil { return 0, err } @@ -69,8 +74,10 @@ func DeleteParameterE(t testing.TestingT, awsRegion string, keyName string) erro if err != nil { return err } - - _, err = ssmClient.DeleteParameter(&ssm.DeleteParameterInput{Name: aws.String(keyName)}) + return DeleteParameterWithClientE(t, ssmClient, keyName) +} +func DeleteParameterWithClientE(t testing.TestingT, client *ssm.SSM, keyName string) error { + _, err := client.DeleteParameter(&ssm.DeleteParameterInput{Name: aws.String(keyName)}) if err != nil { return err } @@ -97,6 +104,14 @@ func NewSsmClientE(t testing.TestingT, region string) (*ssm.SSM, error) { // WaitForSsmInstanceE waits until the instance get registered to the SSM inventory. func WaitForSsmInstanceE(t testing.TestingT, awsRegion, instanceID string, timeout time.Duration) error { + client, err := NewSsmClientE(t, awsRegion) + if err != nil { + return err + } + return WaitForSsmInstanceWithClientE(t, client, instanceID, timeout) +} + +func WaitForSsmInstanceWithClientE(t testing.TestingT, client *ssm.SSM, instanceID string, timeout time.Duration) error { timeBetweenRetries := 2 * time.Second maxRetries := int(timeout.Seconds() / timeBetweenRetries.Seconds()) description := fmt.Sprintf("Waiting for %s to appear in the SSM inventory", instanceID) @@ -111,7 +126,6 @@ func WaitForSsmInstanceE(t testing.TestingT, awsRegion, instanceID string, timeo }, } _, err := retry.DoWithRetryE(t, description, maxRetries, timeBetweenRetries, func() (string, error) { - client := NewSsmClient(t, awsRegion) resp, err := client.GetInventory(input) if err != nil { @@ -152,14 +166,19 @@ type CommandOutput struct { func CheckSsmCommandE(t testing.TestingT, awsRegion, instanceID, command string, timeout time.Duration) (*CommandOutput, error) { logger.Logf(t, "Running command '%s' on EC2 instance with ID '%s'", command, instanceID) - timeBetweenRetries := 2 * time.Second - maxRetries := int(timeout.Seconds() / timeBetweenRetries.Seconds()) - // Now that we know the instance in the SSM inventory, we can send the command client, err := NewSsmClientE(t, awsRegion) if err != nil { return nil, err } + return CheckSSMCommandWithClientE(t, client, instanceID, command, timeout) +} + +func CheckSSMCommandWithClientE(t testing.TestingT, client *ssm.SSM, instanceID, command string, timeout time.Duration) (*CommandOutput, error) { + + timeBetweenRetries := 2 * time.Second + maxRetries := int(timeout.Seconds() / timeBetweenRetries.Seconds()) + resp, err := client.SendCommand(&ssm.SendCommandInput{ Comment: aws.String("Terratest SSM"), DocumentName: aws.String("AWS-RunShellScript"),