Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

When the provider assumes a given role, don't use the default profile… #87

Merged
merged 2 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,11 +395,9 @@ func getClient(conf *ProviderConf) (*elastic7.Client, error) {
return client, nil
}

func assumeRoleCredentials(region, roleARN, roleExternalID, profile string) *awscredentials.Credentials {
sessOpts := awsSessionOptions(region)
if profile == "" {
sessOpts.Profile = "default"
} else {
func assumeRoleCredentials(region, roleARN, roleExternalID, profile string, endpoint string) *awscredentials.Credentials {
sessOpts := awsSessionOptions(region, endpoint)
if profile != "" {
sessOpts.Profile = profile
}

Expand All @@ -417,7 +415,7 @@ func assumeRoleCredentials(region, roleARN, roleExternalID, profile string) *aws
return awscredentials.NewChainCredentials([]awscredentials.Provider{assumeRoleProvider})
}

func awsSessionOptions(region string) awssession.Options {
func awsSessionOptions(region string, endpoint string) awssession.Options {
return awssession.Options{
Config: aws.Config{
Region: aws.String(region),
Expand All @@ -432,13 +430,14 @@ func awsSessionOptions(region string) awssession.Options {
// it fail with Credential error
// https://github.com/aws/aws-sdk-go/issues/2914
HTTPClient: &http.Client{Timeout: 10 * time.Second},
Endpoint: aws.String(endpoint),
},
SharedConfigState: awssession.SharedConfigEnable,
}
}

func awsSession(region string, conf *ProviderConf) *awssession.Session {
sessOpts := awsSessionOptions(region)
func awsSession(region string, conf *ProviderConf, endpoint string) *awssession.Session {
sessOpts := awsSessionOptions(region, endpoint)

// 1. access keys take priority
// 2. next is an assume role configuration
Expand All @@ -452,7 +451,7 @@ func awsSession(region string, conf *ProviderConf) *awssession.Session {
if conf.awsAssumeRoleExternalID == "" {
conf.awsAssumeRoleExternalID = ""
}
sessOpts.Config.Credentials = assumeRoleCredentials(region, conf.awsAssumeRoleArn, conf.awsAssumeRoleExternalID, conf.awsProfile)
sessOpts.Config.Credentials = assumeRoleCredentials(region, conf.awsAssumeRoleArn, conf.awsAssumeRoleExternalID, conf.awsProfile, endpoint)
} else if conf.awsProfile != "" {
sessOpts.Profile = conf.awsProfile
}
Expand All @@ -475,7 +474,7 @@ func awsSession(region string, conf *ProviderConf) *awssession.Session {
}

func awsHttpClient(region string, conf *ProviderConf, headers map[string]string) (*http.Client, error) {
session := awsSession(region, conf)
session := awsSession(region, conf, "")
// Call Get() to ensure concurrency safe retrieval of credentials. Since the
// client is created in many go routines, this synchronizes it.
_, err := session.Config.Credentials.Get()
Expand Down
221 changes: 176 additions & 45 deletions provider/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package provider

import (
"context"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"

"github.com/aws/aws-sdk-go/aws/credentials"
Expand Down Expand Up @@ -82,13 +85,13 @@ func TestAWSCredsManualKey(t *testing.T) {
os.Setenv("AWS_SECRET_ACCESS_KEY", "ENV_SECRET")

// first, check that if we set aws_profile with aws_access_key_id - the latter takes precedence
testConfig := map[string]interface{}{
"aws_profile": namedProfile,
"aws_access_key": manualAccessKeyID,
"aws_secret_key": "MANUAL_SECRET_KEY",
testConfig := &ProviderConf{
awsAccessKeyId: manualAccessKeyID,
awsSecretAccessKey: "MANUAL_SECRET_KEY",
awsProfile: namedProfile,
}

creds := getCreds(t, testRegion, testConfig)
creds := getCreds(t, testRegion, testConfig, "")

if creds.AccessKeyID != manualAccessKeyID {
t.Errorf("access key id should have been %s (we got %s)", manualAccessKeyID, creds.AccessKeyID)
Expand All @@ -106,24 +109,24 @@ func TestAWSCredsNamedProfile(t *testing.T) {
namedProfile := "testing"
profileAccessKeyID := "PROFILE_ACCESS_KEY"

os.Setenv("AWS_CONFIG_FILE", "./test-fixtures/test_aws_config") // set config file so we can ensure the profile we want to test exists
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "./test-fixtures/test_aws_credentials") // set credentials file so we can ensure the profile we want to test exists
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_ACCESS_KEY_ID", envAccessKeyID)
os.Setenv("AWS_SECRET_ACCESS_KEY", "ENV_SECRET")

testConfig := map[string]interface{}{
"aws_profile": namedProfile,
testConfig := &ProviderConf{
awsProfile: namedProfile,
}

creds := getCreds(t, testRegion, testConfig)
creds := getCreds(t, testRegion, testConfig, "")

if creds.AccessKeyID != profileAccessKeyID {
t.Errorf("access key id should have been %s (we got %s)", profileAccessKeyID, creds.AccessKeyID)
}

os.Unsetenv("AWS_ACCESS_KEY_ID")
os.Unsetenv("AWS_SECRET_ACCESS_KEY")
os.Unsetenv("AWS_CONFIG_FILE")
os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE")
os.Unsetenv("AWS_SDK_LOAD_CONFIG")
}

Expand All @@ -132,17 +135,16 @@ func TestAWSCredsNamedProfile(t *testing.T) {
// 2. No configuration provided to the provider
//
// This tests that: we get the credentials from the environment variables (ie: from the default credentials provider chain)

func TestAWSCredsEnv(t *testing.T) {
envAccessKeyID := "ENV_ACCESS_KEY"
testRegion := "us-east-1"

os.Setenv("AWS_ACCESS_KEY_ID", envAccessKeyID)
os.Setenv("AWS_SECRET_ACCESS_KEY", "ENV_SECRET")

testConfig := map[string]interface{}{}
testConfig := &ProviderConf{}

creds := getCreds(t, testRegion, testConfig)
creds := getCreds(t, testRegion, testConfig, "")

if creds.AccessKeyID != envAccessKeyID {
t.Errorf("access key id should have been %s (we got %s)", envAccessKeyID, creds.AccessKeyID)
Expand All @@ -152,72 +154,152 @@ func TestAWSCredsEnv(t *testing.T) {
os.Unsetenv("AWS_SECRET_ACCESS_KEY")
}

// Given:
// 1. AWS profile is specified via environment variables
// 2. No configuration provided to the provider
//
// This tests that: we get the credentials from the environment variables (ie: from the default credentials provider chain)
func TestAWSCredsEnvNamedProfile(t *testing.T) {
namedProfile := "testing"
testRegion := "us-east-1"
profileAccessKeyID := "PROFILE_ACCESS_KEY"

os.Setenv("AWS_PROFILE", namedProfile)
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_CONFIG_FILE", "./test-fixtures/test_aws_config") // set config file so we can ensure the profile we want to test exists
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "./test-fixtures/test_aws_credentials") // set credentials file so we can ensure the profile we want to test exists

testConfig := map[string]interface{}{}
testConfig := &ProviderConf{}

creds := getCreds(t, testRegion, testConfig)
creds := getCreds(t, testRegion, testConfig, "")

if creds.AccessKeyID != profileAccessKeyID {
t.Errorf("access key id should have been %s (we got %s)", profileAccessKeyID, creds.AccessKeyID)
}
os.Unsetenv("AWS_PROFILE")
os.Unsetenv("AWS_CONFIG_FILE")
os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE")
os.Unsetenv("AWS_SDK_LOAD_CONFIG")
}

// Given:
// 1. An AWS role ARN is specified
// 2. No additional AWS configuration is provided to the provider
// 1. AWS credentials are specified via environment variables
// 2. An AWS role ARN and External ID are specified via the provider configuration
//
// This tests that: we can safely generate a session. Note we cannot get the credentials, because that requires connecting to AWS
// This tests that: we can get the credentials after having assumed the given role from the specified AWS credentials.
func TestAWSCredsAssumeRole(t *testing.T) {
envAccessKeyID := "ENV_ACCESS_KEY"
testRegion := "us-east-1"
assumeRoleArn := "arn:aws:iam::123456789012:role/demo/TestAR"
assumeRoleExternalId := "secret_id"
assumeRoleAccessKeyID := "ASIAIOSFODNN7EXAMPLE"

os.Setenv("AWS_ACCESS_KEY_ID", envAccessKeyID)
os.Setenv("AWS_SECRET_ACCESS_KEY", "ENV_SECRET")

testConfig := map[string]interface{}{
"aws_assume_role_arn": "test_arn",
"aws_assume_role_external_id": "secret_id",
server := mockServer{
ResponseFixturePath: "./test-fixtures/api_assume_role_response.xml",
ExpectedAccessKeyId: envAccessKeyID,
ExpectedRoleArn: assumeRoleArn,
ExpectedExternalId: assumeRoleExternalId,
}

testConfigData := schema.TestResourceDataRaw(t, Provider().Schema, testConfig)
server.Start(t)
defer server.Stop()

conf := &ProviderConf{
awsAssumeRoleArn: testConfigData.Get("aws_assume_role_arn").(string),
awsAssumeRoleExternalID: testConfigData.Get("aws_assume_role_external_id").(string),
testConfig := &ProviderConf{
awsAssumeRoleArn: assumeRoleArn,
awsAssumeRoleExternalID: assumeRoleExternalId,
}
s := awsSession(testRegion, conf)
if s == nil {
t.Fatalf("awsSession returned nil")

creds := getCreds(t, testRegion, testConfig, server.Endpoint)

if creds.AccessKeyID != assumeRoleAccessKeyID {
t.Errorf("access key id should have been %s (we got %s)", assumeRoleAccessKeyID, creds.AccessKeyID)
}

os.Unsetenv("AWS_ACCESS_KEY_ID")
os.Unsetenv("AWS_SECRET_ACCESS_KEY")
}

func getCreds(t *testing.T, region string, config map[string]interface{}) credentials.Value {
awsAccessKey := ""
awsSecretKey := ""
awsProfile := ""
if val, ok := config["aws_access_key"]; ok {
awsAccessKey = val.(string)
// Given:
// 1. An AWS profile, role ARN and External ID are specified via the provider configuration
//
// This tests that: we can get the credentials after having assumed the given role from the specified profile.
func TestAWSCredsAssumeRoleFromProfile(t *testing.T) {
testRegion := "us-east-1"
assumeRoleArn := "arn:aws:iam::123456789012:role/demo/TestAR"
assumeRoleExternalId := "secret_id"
namedProfile := "testing"
assumeRoleAccessKeyID := "ASIAIOSFODNN7EXAMPLE"

os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "./test-fixtures/test_aws_credentials") // set credentials file so we can ensure the profile we want to test exists

server := mockServer{
ResponseFixturePath: "./test-fixtures/api_assume_role_response.xml",
ExpectedAccessKeyId: "PROFILE_ACCESS_KEY", // from the test-fixture config file
ExpectedRoleArn: assumeRoleArn,
ExpectedExternalId: assumeRoleExternalId,
}
if val, ok := config["aws_secret_key"]; ok {
awsSecretKey = val.(string)

server.Start(t)
defer server.Stop()

testConfig := &ProviderConf{
awsAssumeRoleArn: assumeRoleArn,
awsAssumeRoleExternalID: assumeRoleExternalId,
awsProfile: namedProfile,
}
if val, ok := config["aws_profile"]; ok {
awsProfile = val.(string)

creds := getCreds(t, testRegion, testConfig, server.Endpoint)

if creds.AccessKeyID != assumeRoleAccessKeyID {
t.Errorf("access key id should have been %s (we got %s)", assumeRoleAccessKeyID, creds.AccessKeyID)
}

conf := &ProviderConf{
awsAccessKeyId: awsAccessKey,
awsSecretAccessKey: awsSecretKey,
awsProfile: awsProfile,
os.Unsetenv("AWS_SDK_LOAD_CONFIG")
os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE")
}

// Given:
// 1. An AWS role ARN and External ID are specified via the provider configuration
//
// This tests that: we can get the credentials after having assumed the given role from the default profile.
func TestAWSCredsAssumeRoleFromDefaultProfile(t *testing.T) {
testRegion := "us-east-1"
assumeRoleArn := "arn:aws:iam::123456789012:role/demo/TestAR"
assumeRoleExternalId := "secret_id"
assumeRoleAccessKeyID := "ASIAIOSFODNN7EXAMPLE"

os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "./test-fixtures/test_aws_credentials") // set credentials file so we can ensure the profile we want to test exists

server := mockServer{
ResponseFixturePath: "./test-fixtures/api_assume_role_response.xml",
ExpectedAccessKeyId: "PROFILE_DEFAULT_ACCESS_KEY", // from the test-fixture config file
ExpectedRoleArn: assumeRoleArn,
ExpectedExternalId: assumeRoleExternalId,
}

server.Start(t)
defer server.Stop()

testConfig := &ProviderConf{
awsAssumeRoleArn: assumeRoleArn,
awsAssumeRoleExternalID: assumeRoleExternalId,
}
s := awsSession(region, conf)

creds := getCreds(t, testRegion, testConfig, server.Endpoint)

if creds.AccessKeyID != assumeRoleAccessKeyID {
t.Errorf("access key id should have been %s (we got %s)", assumeRoleAccessKeyID, creds.AccessKeyID)
}

os.Unsetenv("AWS_SDK_LOAD_CONFIG")
os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE")
}

func getCreds(t *testing.T, region string, config *ProviderConf, endpoint string) credentials.Value {
s := awsSession(region, config, endpoint)
if s == nil {
t.Fatalf("awsSession returned nil")
}
Expand All @@ -227,3 +309,52 @@ func getCreds(t *testing.T, region string, config map[string]interface{}) creden
}
return creds
}

type mockServer struct {
ResponseFixturePath string
ExpectedAccessKeyId string
ExpectedRoleArn string
ExpectedExternalId string
Endpoint string
server *httptest.Server
}

func (s *mockServer) Start(t *testing.T) {
s.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

auth := r.Header.Get("Authorization")
if !strings.Contains(auth, s.ExpectedAccessKeyId) {
t.Errorf("Could not find expected access key id %s in authorization header %s", s.ExpectedAccessKeyId, auth)
}

err := r.ParseForm()
if err != nil {
t.Errorf("Error while parsing form: %v", err)
}

if r.PostForm.Get("RoleArn") != s.ExpectedRoleArn {
t.Errorf("expected RoleArn to be equal to %s, but got %s", s.ExpectedRoleArn, r.PostForm.Get("RoleArn"))
}

if r.PostForm.Get("ExternalId") != s.ExpectedExternalId {
t.Errorf("expected ExternalId to be equal to %s, but got %s", s.ExpectedExternalId, r.PostForm.Get("ExternalId"))
}

response, err := os.ReadFile(s.ResponseFixturePath)
if err != nil {
t.Errorf("Error while reading mockResponse %v", err)
}

w.WriteHeader(http.StatusOK)
_, err = w.Write(response)
if err != nil {
t.Errorf("Error while writing mock server response %v", err)
}
}))

s.Endpoint = s.server.URL
}

func (s *mockServer) Stop() {
s.server.Close()
}
Loading
Loading