Skip to content

Commit

Permalink
PutSecureParameter and recursive GetAllParametersByPath (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
ingshtrom authored Sep 20, 2020
1 parent 40470bb commit 457696c
Show file tree
Hide file tree
Showing 2 changed files with 211 additions and 9 deletions.
53 changes: 53 additions & 0 deletions parameter_store_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package awsssm

import (
"errors"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/session"
Expand All @@ -18,6 +19,7 @@ var (
type ssmClient interface {
GetParametersByPathPages(input *ssm.GetParametersByPathInput, fn func(*ssm.GetParametersByPathOutput, bool) bool) error
GetParameter(input *ssm.GetParameterInput) (*ssm.GetParameterOutput, error)
PutParameter(input *ssm.PutParameterInput) (*ssm.PutParameterOutput, error)
}

//ParameterStore holds all the methods tha are supported against AWS Parameter Store
Expand All @@ -30,6 +32,8 @@ type ParameterStore struct {
//Will return /my-service/dev/param-a, /my-service/dev/param-b, etc... but will not return recursive paths
//the `ssm:GetAllParametersByPath` permission is required
//to the `arn:aws:ssm:aws-region:aws-account-id:/my-service/dev/*`
//
//This will also page through and return all elements in the hierarchy, non-recursively
func (ps *ParameterStore) GetAllParametersByPath(path string, decrypt bool) (*Parameters, error) {
var input = &ssm.GetParametersByPathInput{}
input.SetWithDecryption(decrypt)
Expand Down Expand Up @@ -81,6 +85,55 @@ func (ps *ParameterStore) getParameter(input *ssm.GetParameterInput) (*Parameter
}, nil
}

//PutSecureParameter is setting the parameter with the given name to a passed in value.
//Allow overwriting the value of the parameter already exists, otherwise an error is returned
//For example a request with name as '/my-service/dev/param-1':
//Will set the parameter value if exists or ErrParameterInvalidName if parameter already exists or is empty
// and `overwrite` is false. The `ssm:PutParameter` permission is required to the
//`arn:aws:ssm:aws-region:aws-account-id:/my-service/dev/param-1` resource
func (ps *ParameterStore) PutSecureParameter(name, value string, overwrite bool) error {
return ps.putSecureParameterWrapper(name, value, "", overwrite)
}

//PutSecureParameterWithCMK is the same as PutSecureParameter but with a passed in CMK (Customer Master Key)
//For example a request with name as '/my-service/dev/param-1' and a `kmsID` of 'foo':
//Will set the parameter value if exists or ErrParameterInvalidName if parameter already exists or is empty
// and `overwrite` is false. The `ssm:PutParameter` permission is required to the
//`arn:aws:ssm:aws-region:aws-account-id:/my-service/dev/param-1` resource
// The `kms:Encrypt` permission is required to the `arn:aws:kms:us-east-1:710015040892:key/foo`
func (ps *ParameterStore) PutSecureParameterWithCMK(name, value string, overwrite bool, kmsID string) error {
return ps.putSecureParameterWrapper(name, value, kmsID, overwrite)
}
func (ps *ParameterStore) putSecureParameterWrapper(name, value, kmsID string, overwrite bool) error {
if name == "" {
return ErrParameterInvalidName
}
input := &ssm.PutParameterInput{}
input.SetName(name)
input.SetType("SecureString")
input.SetValue(value)
if kmsID != "" {
input.SetKeyId(kmsID)
}
input.SetOverwrite(overwrite)

if err := input.Validate(); err != nil {
return err
}

return ps.putParameter(input)
}
func (ps *ParameterStore) putParameter(input *ssm.PutParameterInput) error {
_, err := ps.ssm.PutParameter(input)
if err != nil {
if awsError, ok := err.(awserr.Error); ok && awsError.Code() == ssm.ErrCodeParameterAlreadyExists {
return ErrParameterInvalidName
}
return err
}
return nil
}

//NewParameterStoreWithClient is creating a new ParameterStore with the given ssm Client
func NewParameterStoreWithClient(client ssmClient) *ParameterStore {
return &ParameterStore{ssm: client}
Expand Down
167 changes: 158 additions & 9 deletions parameter_store_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package awsssm

import (
"errors"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/ssm"
"reflect"
"testing"

"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/ssm"
)

var param1 = new(ssm.Parameter).
Expand All @@ -18,6 +19,7 @@ var param2 = new(ssm.Parameter).
SetValue("rds.something.aws.com").
SetARN("arn:aws:ssm:us-east-2:aws-account-id:/my-service/dev/DB_HOST")

// return s.GetParametersByPathOutput, s.GetParametersByPathError
var param3 = new(ssm.Parameter).
SetName("/my-service/dev/DB_USERNAME").
SetValue("username").
Expand All @@ -31,13 +33,14 @@ type stubGetParametersByPathOutput struct {
}

type stubSSMClient struct {
GetParametersByPathOutput []stubGetParametersByPathOutput
GetParametersByPathError error
GetParameterOutput *ssm.GetParameterOutput
GetParameterError error
GetParametersByPathOutput []stubGetParametersByPathOutput
GetParametersByPathError error
GetParameterOutput *ssm.GetParameterOutput
GetParameterError error
PutParameterInputReceived *ssm.PutParameterInput
}

func (s stubSSMClient) GetParametersByPathPages(input *ssm.GetParametersByPathInput, fn func(*ssm.GetParametersByPathOutput, bool) bool) error {
func (s *stubSSMClient) GetParametersByPathPages(input *ssm.GetParametersByPathInput, fn func(*ssm.GetParametersByPathOutput, bool) bool) error {
if s.GetParametersByPathError == nil {
for _, output := range s.GetParametersByPathOutput {
done := fn(&output.Output, output.MoreParamsLeft)
Expand All @@ -49,10 +52,17 @@ func (s stubSSMClient) GetParametersByPathPages(input *ssm.GetParametersByPathIn
return s.GetParametersByPathError
}

func (s stubSSMClient) GetParameter(input *ssm.GetParameterInput) (*ssm.GetParameterOutput, error) {
func (s *stubSSMClient) GetParameter(input *ssm.GetParameterInput) (*ssm.GetParameterOutput, error) {
return s.GetParameterOutput, s.GetParameterError
}

// we return nothing becuase the actual response is pretty boring. Just a version number. We DO
// want to track was is input because there is a _little_ business logic around that
func (s *stubSSMClient) PutParameter(input *ssm.PutParameterInput) (*ssm.PutParameterOutput, error) {
s.PutParameterInputReceived = input
return nil, nil
}

func TestClient_GetParametersByPath(t *testing.T) {
tests := []struct {
name string
Expand All @@ -71,6 +81,12 @@ func TestClient_GetParametersByPath(t *testing.T) {
Parameters: getParameters(),
},
},
{
MoreParamsLeft: true,
Output: ssm.GetParametersByPathOutput{
Parameters: getParameters2(),
},
},
{
MoreParamsLeft: false,
Output: ssm.GetParametersByPathOutput{
Expand Down Expand Up @@ -110,7 +126,7 @@ func TestClient_GetParametersByPath(t *testing.T) {
t.Errorf(`Unexpected error: got %d, expected %d`, err, test.expectedError)
}
if !reflect.DeepEqual(parameters, test.expectedOutput) {
t.Error(`Unexpected parameters`, *parameters, *test.expectedOutput)
t.Errorf(`Unexpected parameters: got: %+v, expected: %+v`, *parameters, *test.expectedOutput)
}
})
}
Expand All @@ -122,6 +138,12 @@ func getParameters() []*ssm.Parameter {
}
}

func getParameters2() []*ssm.Parameter {
return []*ssm.Parameter{
param3,
}
}

func TestParameterStore_GetParameter(t *testing.T) {
value := "something-secure"
tests := []struct {
Expand Down Expand Up @@ -172,3 +194,130 @@ func TestParameterStore_GetParameter(t *testing.T) {
})
}
}

func TestParameterStore_PutSecureParameter(t *testing.T) {
paramName := "foo"
paramValue := "baz"
paramType := "SecureString"
overwriteTrue := true
overwriteFalse := false

tests := []struct {
name string
ssmClient *stubSSMClient
parameterName string
parameterValue string
overwrite bool
expectedError error
expectedInput *ssm.PutParameterInput
}{
{
name: "Failed Empty name",
ssmClient: &stubSSMClient{},
parameterName: "",
parameterValue: "",
expectedError: ErrParameterInvalidName,
},
{
name: "Set Correct Defaults",
ssmClient: &stubSSMClient{},
parameterName: paramName,
parameterValue: paramValue,
expectedInput: &ssm.PutParameterInput{
Name: &paramName,
Type: &paramType,
Value: &paramValue,
Overwrite: &overwriteFalse,
},
},
{
name: "Overwrite Changes Propagate",
ssmClient: &stubSSMClient{},
parameterName: paramName,
parameterValue: paramValue,
overwrite: overwriteTrue,
expectedInput: &ssm.PutParameterInput{
Name: &paramName,
Type: &paramType,
Value: &paramValue,
Overwrite: &overwriteTrue,
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
client := NewParameterStoreWithClient(test.ssmClient)
err := client.PutSecureParameter(test.parameterName, test.parameterValue, test.overwrite)
if err != test.expectedError {
t.Errorf(`Unexpected error: got %d, expected %d`, err, test.expectedError)
}
if !reflect.DeepEqual(test.ssmClient.PutParameterInputReceived, test.expectedInput) {
t.Errorf(`Unexpected parameter: got %v, expected %v`, test.ssmClient.PutParameterInputReceived, test.expectedInput)
}
})
}
}

func TestParameterStore_PutSecureParameterWithCMK(t *testing.T) {
paramName := "foo"
paramValue := "baz"
paramType := "SecureString"
overwriteFalse := false
kmsID := "super-secret-kms"
tests := []struct {
name string
ssmClient *stubSSMClient
parameterName string
parameterValue string
overwrite bool
kmsID string
expectedError error
expectedInput *ssm.PutParameterInput
}{
{
name: "Failed Empty name",
ssmClient: &stubSSMClient{},
parameterName: "",
parameterValue: "",
expectedError: ErrParameterInvalidName,
},
{
name: "Set Correct Defaults",
ssmClient: &stubSSMClient{},
parameterName: paramName,
parameterValue: paramValue,
expectedInput: &ssm.PutParameterInput{
Name: &paramName,
Overwrite: &overwriteFalse,
Type: &paramType,
Value: &paramValue,
},
},
{
name: "KMS ID Changes Propagate",
ssmClient: &stubSSMClient{},
parameterName: paramName,
parameterValue: paramValue,
kmsID: kmsID,
expectedInput: &ssm.PutParameterInput{
KeyId: &kmsID,
Name: &paramName,
Overwrite: &overwriteFalse,
Type: &paramType,
Value: &paramValue,
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
client := NewParameterStoreWithClient(test.ssmClient)
err := client.PutSecureParameterWithCMK(test.parameterName, test.parameterValue, test.overwrite, test.kmsID)
if err != test.expectedError {
t.Errorf(`Unexpected error: got %d, expected %d`, err, test.expectedError)
}
if !reflect.DeepEqual(test.ssmClient.PutParameterInputReceived, test.expectedInput) {
t.Errorf(`Unexpected parameter: got %v, expected %v`, test.ssmClient.PutParameterInputReceived, test.expectedInput)
}
})
}
}

0 comments on commit 457696c

Please sign in to comment.