Skip to content

Commit

Permalink
When the provider assumes a given role, don't use the default profile… (
Browse files Browse the repository at this point in the history
#87)

* When the provider assumes a given role, don't use the default profile if the profile is not given, but allow aws-sdk-go to find the credentials using the default credential provider chain (#86)

Signed-off-by: Massimo Battestini <[email protected]>

* Adds unit tests for AWS profile change (#86)

Signed-off-by: Massimo Battestini <[email protected]>

---------

Signed-off-by: Massimo Battestini <[email protected]>
  • Loading branch information
massimob76 authored Nov 1, 2023
1 parent b33b9cb commit 0cfc9f2
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 58 deletions.
19 changes: 9 additions & 10 deletions provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -396,11 +396,9 @@ func getClient(conf *ProviderConf) (*elastic7.Client, error) {
return client, nil
}

func assumeRoleCredentials(region, roleARN, roleExternalID, profile string) *awscredentials.Credentials {
sessOpts := awsSessionOptions(region)
if profile == "" {
sessOpts.Profile = "default"
} else {
func assumeRoleCredentials(region, roleARN, roleExternalID, profile string, endpoint string) *awscredentials.Credentials {
sessOpts := awsSessionOptions(region, endpoint)
if profile != "" {
sessOpts.Profile = profile
}

Expand All @@ -418,7 +416,7 @@ func assumeRoleCredentials(region, roleARN, roleExternalID, profile string) *aws
return awscredentials.NewChainCredentials([]awscredentials.Provider{assumeRoleProvider})
}

func awsSessionOptions(region string) awssession.Options {
func awsSessionOptions(region string, endpoint string) awssession.Options {
return awssession.Options{
Config: aws.Config{
Region: aws.String(region),
Expand All @@ -433,13 +431,14 @@ func awsSessionOptions(region string) awssession.Options {
// it fail with Credential error
// https://github.com/aws/aws-sdk-go/issues/2914
HTTPClient: &http.Client{Timeout: 10 * time.Second},
Endpoint: aws.String(endpoint),
},
SharedConfigState: awssession.SharedConfigEnable,
}
}

func awsSession(region string, conf *ProviderConf) *awssession.Session {
sessOpts := awsSessionOptions(region)
func awsSession(region string, conf *ProviderConf, endpoint string) *awssession.Session {
sessOpts := awsSessionOptions(region, endpoint)

// 1. access keys take priority
// 2. next is an assume role configuration
Expand All @@ -453,7 +452,7 @@ func awsSession(region string, conf *ProviderConf) *awssession.Session {
if conf.awsAssumeRoleExternalID == "" {
conf.awsAssumeRoleExternalID = ""
}
sessOpts.Config.Credentials = assumeRoleCredentials(region, conf.awsAssumeRoleArn, conf.awsAssumeRoleExternalID, conf.awsProfile)
sessOpts.Config.Credentials = assumeRoleCredentials(region, conf.awsAssumeRoleArn, conf.awsAssumeRoleExternalID, conf.awsProfile, endpoint)
} else if conf.awsProfile != "" {
sessOpts.Profile = conf.awsProfile
}
Expand All @@ -476,7 +475,7 @@ func awsSession(region string, conf *ProviderConf) *awssession.Session {
}

func awsHttpClient(region string, conf *ProviderConf, headers map[string]string) (*http.Client, error) {
session := awsSession(region, conf)
session := awsSession(region, conf, "")
// Call Get() to ensure concurrency safe retrieval of credentials. Since the
// client is created in many go routines, this synchronizes it.
_, err := session.Config.Credentials.Get()
Expand Down
221 changes: 176 additions & 45 deletions provider/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package provider

import (
"context"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"

"github.com/aws/aws-sdk-go/aws/credentials"
Expand Down Expand Up @@ -82,13 +85,13 @@ func TestAWSCredsManualKey(t *testing.T) {
os.Setenv("AWS_SECRET_ACCESS_KEY", "ENV_SECRET")

// first, check that if we set aws_profile with aws_access_key_id - the latter takes precedence
testConfig := map[string]interface{}{
"aws_profile": namedProfile,
"aws_access_key": manualAccessKeyID,
"aws_secret_key": "MANUAL_SECRET_KEY",
testConfig := &ProviderConf{
awsAccessKeyId: manualAccessKeyID,
awsSecretAccessKey: "MANUAL_SECRET_KEY",
awsProfile: namedProfile,
}

creds := getCreds(t, testRegion, testConfig)
creds := getCreds(t, testRegion, testConfig, "")

if creds.AccessKeyID != manualAccessKeyID {
t.Errorf("access key id should have been %s (we got %s)", manualAccessKeyID, creds.AccessKeyID)
Expand All @@ -106,24 +109,24 @@ func TestAWSCredsNamedProfile(t *testing.T) {
namedProfile := "testing"
profileAccessKeyID := "PROFILE_ACCESS_KEY"

os.Setenv("AWS_CONFIG_FILE", "./test-fixtures/test_aws_config") // set config file so we can ensure the profile we want to test exists
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "./test-fixtures/test_aws_credentials") // set credentials file so we can ensure the profile we want to test exists
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_ACCESS_KEY_ID", envAccessKeyID)
os.Setenv("AWS_SECRET_ACCESS_KEY", "ENV_SECRET")

testConfig := map[string]interface{}{
"aws_profile": namedProfile,
testConfig := &ProviderConf{
awsProfile: namedProfile,
}

creds := getCreds(t, testRegion, testConfig)
creds := getCreds(t, testRegion, testConfig, "")

if creds.AccessKeyID != profileAccessKeyID {
t.Errorf("access key id should have been %s (we got %s)", profileAccessKeyID, creds.AccessKeyID)
}

os.Unsetenv("AWS_ACCESS_KEY_ID")
os.Unsetenv("AWS_SECRET_ACCESS_KEY")
os.Unsetenv("AWS_CONFIG_FILE")
os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE")
os.Unsetenv("AWS_SDK_LOAD_CONFIG")
}

Expand All @@ -132,17 +135,16 @@ func TestAWSCredsNamedProfile(t *testing.T) {
// 2. No configuration provided to the provider
//
// This tests that: we get the credentials from the environment variables (ie: from the default credentials provider chain)

func TestAWSCredsEnv(t *testing.T) {
envAccessKeyID := "ENV_ACCESS_KEY"
testRegion := "us-east-1"

os.Setenv("AWS_ACCESS_KEY_ID", envAccessKeyID)
os.Setenv("AWS_SECRET_ACCESS_KEY", "ENV_SECRET")

testConfig := map[string]interface{}{}
testConfig := &ProviderConf{}

creds := getCreds(t, testRegion, testConfig)
creds := getCreds(t, testRegion, testConfig, "")

if creds.AccessKeyID != envAccessKeyID {
t.Errorf("access key id should have been %s (we got %s)", envAccessKeyID, creds.AccessKeyID)
Expand All @@ -152,72 +154,152 @@ func TestAWSCredsEnv(t *testing.T) {
os.Unsetenv("AWS_SECRET_ACCESS_KEY")
}

// Given:
// 1. AWS profile is specified via environment variables
// 2. No configuration provided to the provider
//
// This tests that: we get the credentials from the environment variables (ie: from the default credentials provider chain)
func TestAWSCredsEnvNamedProfile(t *testing.T) {
namedProfile := "testing"
testRegion := "us-east-1"
profileAccessKeyID := "PROFILE_ACCESS_KEY"

os.Setenv("AWS_PROFILE", namedProfile)
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_CONFIG_FILE", "./test-fixtures/test_aws_config") // set config file so we can ensure the profile we want to test exists
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "./test-fixtures/test_aws_credentials") // set credentials file so we can ensure the profile we want to test exists

testConfig := map[string]interface{}{}
testConfig := &ProviderConf{}

creds := getCreds(t, testRegion, testConfig)
creds := getCreds(t, testRegion, testConfig, "")

if creds.AccessKeyID != profileAccessKeyID {
t.Errorf("access key id should have been %s (we got %s)", profileAccessKeyID, creds.AccessKeyID)
}
os.Unsetenv("AWS_PROFILE")
os.Unsetenv("AWS_CONFIG_FILE")
os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE")
os.Unsetenv("AWS_SDK_LOAD_CONFIG")
}

// Given:
// 1. An AWS role ARN is specified
// 2. No additional AWS configuration is provided to the provider
// 1. AWS credentials are specified via environment variables
// 2. An AWS role ARN and External ID are specified via the provider configuration
//
// This tests that: we can safely generate a session. Note we cannot get the credentials, because that requires connecting to AWS
// This tests that: we can get the credentials after having assumed the given role from the specified AWS credentials.
func TestAWSCredsAssumeRole(t *testing.T) {
envAccessKeyID := "ENV_ACCESS_KEY"
testRegion := "us-east-1"
assumeRoleArn := "arn:aws:iam::123456789012:role/demo/TestAR"
assumeRoleExternalId := "secret_id"
assumeRoleAccessKeyID := "ASIAIOSFODNN7EXAMPLE"

os.Setenv("AWS_ACCESS_KEY_ID", envAccessKeyID)
os.Setenv("AWS_SECRET_ACCESS_KEY", "ENV_SECRET")

testConfig := map[string]interface{}{
"aws_assume_role_arn": "test_arn",
"aws_assume_role_external_id": "secret_id",
server := mockServer{
ResponseFixturePath: "./test-fixtures/api_assume_role_response.xml",
ExpectedAccessKeyId: envAccessKeyID,
ExpectedRoleArn: assumeRoleArn,
ExpectedExternalId: assumeRoleExternalId,
}

testConfigData := schema.TestResourceDataRaw(t, Provider().Schema, testConfig)
server.Start(t)
defer server.Stop()

conf := &ProviderConf{
awsAssumeRoleArn: testConfigData.Get("aws_assume_role_arn").(string),
awsAssumeRoleExternalID: testConfigData.Get("aws_assume_role_external_id").(string),
testConfig := &ProviderConf{
awsAssumeRoleArn: assumeRoleArn,
awsAssumeRoleExternalID: assumeRoleExternalId,
}
s := awsSession(testRegion, conf)
if s == nil {
t.Fatalf("awsSession returned nil")

creds := getCreds(t, testRegion, testConfig, server.Endpoint)

if creds.AccessKeyID != assumeRoleAccessKeyID {
t.Errorf("access key id should have been %s (we got %s)", assumeRoleAccessKeyID, creds.AccessKeyID)
}

os.Unsetenv("AWS_ACCESS_KEY_ID")
os.Unsetenv("AWS_SECRET_ACCESS_KEY")
}

func getCreds(t *testing.T, region string, config map[string]interface{}) credentials.Value {
awsAccessKey := ""
awsSecretKey := ""
awsProfile := ""
if val, ok := config["aws_access_key"]; ok {
awsAccessKey = val.(string)
// Given:
// 1. An AWS profile, role ARN and External ID are specified via the provider configuration
//
// This tests that: we can get the credentials after having assumed the given role from the specified profile.
func TestAWSCredsAssumeRoleFromProfile(t *testing.T) {
testRegion := "us-east-1"
assumeRoleArn := "arn:aws:iam::123456789012:role/demo/TestAR"
assumeRoleExternalId := "secret_id"
namedProfile := "testing"
assumeRoleAccessKeyID := "ASIAIOSFODNN7EXAMPLE"

os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "./test-fixtures/test_aws_credentials") // set credentials file so we can ensure the profile we want to test exists

server := mockServer{
ResponseFixturePath: "./test-fixtures/api_assume_role_response.xml",
ExpectedAccessKeyId: "PROFILE_ACCESS_KEY", // from the test-fixture config file
ExpectedRoleArn: assumeRoleArn,
ExpectedExternalId: assumeRoleExternalId,
}
if val, ok := config["aws_secret_key"]; ok {
awsSecretKey = val.(string)

server.Start(t)
defer server.Stop()

testConfig := &ProviderConf{
awsAssumeRoleArn: assumeRoleArn,
awsAssumeRoleExternalID: assumeRoleExternalId,
awsProfile: namedProfile,
}
if val, ok := config["aws_profile"]; ok {
awsProfile = val.(string)

creds := getCreds(t, testRegion, testConfig, server.Endpoint)

if creds.AccessKeyID != assumeRoleAccessKeyID {
t.Errorf("access key id should have been %s (we got %s)", assumeRoleAccessKeyID, creds.AccessKeyID)
}

conf := &ProviderConf{
awsAccessKeyId: awsAccessKey,
awsSecretAccessKey: awsSecretKey,
awsProfile: awsProfile,
os.Unsetenv("AWS_SDK_LOAD_CONFIG")
os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE")
}

// Given:
// 1. An AWS role ARN and External ID are specified via the provider configuration
//
// This tests that: we can get the credentials after having assumed the given role from the default profile.
func TestAWSCredsAssumeRoleFromDefaultProfile(t *testing.T) {
testRegion := "us-east-1"
assumeRoleArn := "arn:aws:iam::123456789012:role/demo/TestAR"
assumeRoleExternalId := "secret_id"
assumeRoleAccessKeyID := "ASIAIOSFODNN7EXAMPLE"

os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "./test-fixtures/test_aws_credentials") // set credentials file so we can ensure the profile we want to test exists

server := mockServer{
ResponseFixturePath: "./test-fixtures/api_assume_role_response.xml",
ExpectedAccessKeyId: "PROFILE_DEFAULT_ACCESS_KEY", // from the test-fixture config file
ExpectedRoleArn: assumeRoleArn,
ExpectedExternalId: assumeRoleExternalId,
}

server.Start(t)
defer server.Stop()

testConfig := &ProviderConf{
awsAssumeRoleArn: assumeRoleArn,
awsAssumeRoleExternalID: assumeRoleExternalId,
}
s := awsSession(region, conf)

creds := getCreds(t, testRegion, testConfig, server.Endpoint)

if creds.AccessKeyID != assumeRoleAccessKeyID {
t.Errorf("access key id should have been %s (we got %s)", assumeRoleAccessKeyID, creds.AccessKeyID)
}

os.Unsetenv("AWS_SDK_LOAD_CONFIG")
os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE")
}

func getCreds(t *testing.T, region string, config *ProviderConf, endpoint string) credentials.Value {
s := awsSession(region, config, endpoint)
if s == nil {
t.Fatalf("awsSession returned nil")
}
Expand All @@ -227,3 +309,52 @@ func getCreds(t *testing.T, region string, config map[string]interface{}) creden
}
return creds
}

type mockServer struct {
ResponseFixturePath string
ExpectedAccessKeyId string
ExpectedRoleArn string
ExpectedExternalId string
Endpoint string
server *httptest.Server
}

func (s *mockServer) Start(t *testing.T) {
s.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

auth := r.Header.Get("Authorization")
if !strings.Contains(auth, s.ExpectedAccessKeyId) {
t.Errorf("Could not find expected access key id %s in authorization header %s", s.ExpectedAccessKeyId, auth)
}

err := r.ParseForm()
if err != nil {
t.Errorf("Error while parsing form: %v", err)
}

if r.PostForm.Get("RoleArn") != s.ExpectedRoleArn {
t.Errorf("expected RoleArn to be equal to %s, but got %s", s.ExpectedRoleArn, r.PostForm.Get("RoleArn"))
}

if r.PostForm.Get("ExternalId") != s.ExpectedExternalId {
t.Errorf("expected ExternalId to be equal to %s, but got %s", s.ExpectedExternalId, r.PostForm.Get("ExternalId"))
}

response, err := os.ReadFile(s.ResponseFixturePath)
if err != nil {
t.Errorf("Error while reading mockResponse %v", err)
}

w.WriteHeader(http.StatusOK)
_, err = w.Write(response)
if err != nil {
t.Errorf("Error while writing mock server response %v", err)
}
}))

s.Endpoint = s.server.URL
}

func (s *mockServer) Stop() {
s.server.Close()
}
Loading

0 comments on commit 0cfc9f2

Please sign in to comment.