diff --git a/go.mod b/go.mod index 5d45d4ac..0a40be67 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.22.5 require ( github.com/aws/amazon-vpc-cni-k8s v1.18.1 - github.com/aws/aws-sdk-go v1.51.32 + github.com/aws/aws-sdk-go v1.55.5 github.com/go-logr/logr v1.4.2 github.com/go-logr/zapr v1.3.0 github.com/golang/mock v1.6.0 diff --git a/go.sum b/go.sum index 6069a9a2..e6bc54de 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPd github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= github.com/aws/amazon-vpc-cni-k8s v1.18.1 h1:u/OeBgnUUX6f3PCEOpA4dbG0+iZ71CnY6tEljjrl3iw= github.com/aws/amazon-vpc-cni-k8s v1.18.1/go.mod h1:m/J5GsxF0Th2iQTOE3ww4W9LFvwdC0tGyA9dIL4h6iQ= -github.com/aws/aws-sdk-go v1.51.32 h1:A6mPui7QP4mwmovyzgtdedbRbNur1Iu0/El7hBWNHms= -github.com/aws/aws-sdk-go v1.51.32/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= +github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU= +github.com/aws/aws-sdk-go v1.55.5/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= diff --git a/pkg/aws/ec2/api/wrapper.go b/pkg/aws/ec2/api/wrapper.go index 97a20f3b..ef3c7d1e 100644 --- a/pkg/aws/ec2/api/wrapper.go +++ b/pkg/aws/ec2/api/wrapper.go @@ -510,21 +510,21 @@ func (e *ec2Wrapper) getClientUsingAssumedRole(instanceRegion, roleARN, clusterN } e.log.Info("created rate limited http client", "qps", qps, "burst", burst) - // Get the regional sts end point - regionalSTSEndpoint, err := endpoints.DefaultResolver(). - EndpointFor("sts", aws.StringValue(userStsSession.Config.Region), endpoints.STSRegionalEndpointOption) - if err != nil { - return nil, fmt.Errorf("failed to get the regional sts endoint for region %s: %v", - *userStsSession.Config.Region, err) - } - + // GetPartition ID, SourceAccount and SourceARN roleARN = strings.Trim(roleARN, "\"") - sourceAcct, sourceArn, err := utils.GetSourceAcctAndArn(roleARN, region, clusterName) + sourceAcct, partitionID, sourceArn, err := utils.GetSourceAcctAndArn(roleARN, region, clusterName) if err != nil { return nil, err } + // Get the regional sts end point + regionalSTSEndpoint, err := e.getRegionalStsEndpoint(partitionID, region) + if err != nil { + return nil, fmt.Errorf("failed to get the regional sts endpoint for region %s: %v %v", + *userStsSession.Config.Region, err, partitionID) + } + regionalProvider := &stscreds.AssumeRoleProvider{ Client: e.createSTSClient(userStsSession, client, regionalSTSEndpoint, sourceAcct, sourceArn), RoleARN: roleARN, @@ -892,3 +892,35 @@ func (e *ec2Wrapper) DisassociateTrunkInterface(input *ec2.DisassociateTrunkInte } return err } + +func (e *ec2Wrapper) getRegionalStsEndpoint(partitionID, region string) (endpoints.ResolvedEndpoint, error) { + var partition *endpoints.Partition + var stsServiceID = "sts" + for _, p := range endpoints.DefaultPartitions() { + if partitionID == p.ID() { + partition = &p + break + } + } + if partition == nil { + return endpoints.ResolvedEndpoint{}, fmt.Errorf("partition %s not valid", partitionID) + } + + stsSvc, ok := partition.Services()[stsServiceID] + if !ok { + e.log.Info("STS service not found in partition, generating default endpoint.", "Partition:", partitionID) + // Add the host of the current instances region if the service doesn't already exists in the partition + // so we don't fail if the service is not present in the go sdk but matches the instances region. + res, err := partition.EndpointFor(stsServiceID, region, endpoints.STSRegionalEndpointOption, endpoints.ResolveUnknownServiceOption) + if err != nil { + return endpoints.ResolvedEndpoint{}, fmt.Errorf("error resolving endpoint for %s in partition %s. err: %v", region, partition.ID(), err) + } + return res, nil + } + + res, err := stsSvc.ResolveEndpoint(region, endpoints.STSRegionalEndpointOption) + if err != nil { + return endpoints.ResolvedEndpoint{}, fmt.Errorf("error resolving endpoint for %s in partition %s. err: %v", region, partition.ID(), err) + } + return res, nil +} diff --git a/pkg/aws/ec2/api/wrapper_test.go b/pkg/aws/ec2/api/wrapper_test.go new file mode 100644 index 00000000..281a8d59 --- /dev/null +++ b/pkg/aws/ec2/api/wrapper_test.go @@ -0,0 +1,65 @@ +package api + +import ( + "testing" +) + +func getMockEC2Wrapper() ec2Wrapper { + return ec2Wrapper{} +} +func Test_getRegionalStsEndpoint(t *testing.T) { + + ec2Wapper := getMockEC2Wrapper() + + type args struct { + partitionID string + region string + } + + tests := []struct { + name string + args args + want string + wantErr bool + }{ + { + name: "service doesn't exist in partition", + args: args{ + partitionID: "aws-iso-f", + region: "testregions", + }, + want: "https://sts.testregions.csp.hci.ic.gov", + wantErr: false, + }, + { + name: "region doesn't exist in partition", + args: args{ + partitionID: "aws", + region: "us-test-2", + }, + want: "https://sts.us-test-2.amazonaws.com", + wantErr: false, + }, + { + name: "region and service exist in partition", + args: args{ + partitionID: "aws", + region: "us-west-2", + }, + want: "https://sts.us-west-2.amazonaws.com", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ec2Wapper.getRegionalStsEndpoint(tt.args.partitionID, tt.args.region) + if (err != nil) != tt.wantErr { + t.Errorf("getRegionalStsEndpoint() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got.URL != tt.want { + t.Errorf("getRegionalStsEndpoint() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/utils/helper.go b/pkg/utils/helper.go index d6b6eeac..14fda665 100644 --- a/pkg/utils/helper.go +++ b/pkg/utils/helper.go @@ -213,22 +213,22 @@ func IsNitroInstance(instanceType string) (bool, error) { } // GetSourceAcctAndArn constructs source acct and arn and return them for use -func GetSourceAcctAndArn(roleARN, region, clusterName string) (string, string, error) { +func GetSourceAcctAndArn(roleARN, region, clusterName string) (string, string, string, error) { // ARN format (https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html) // arn:partition:service:region:account-id:resource-type/resource-id // IAM format, region is always blank // arn:aws:iam::account:role/role-name-with-path if !arn.IsARN(roleARN) { - return "", "", fmt.Errorf("incorrect ARN format for role %s", roleARN) + return "", "", "", fmt.Errorf("incorrect ARN format for role %s", roleARN) } else if region == "" { - return "", "", nil + return "", "", "", nil } parsedArn, err := arn.Parse(roleARN) if err != nil { - return "", "", err + return "", "", "", err } sourceArn := fmt.Sprintf("arn:%s:eks:%s:%s:cluster/%s", parsedArn.Partition, region, parsedArn.AccountID, clusterName) - return parsedArn.AccountID, sourceArn, nil + return parsedArn.AccountID, parsedArn.Partition, sourceArn, nil } diff --git a/pkg/utils/helper_test.go b/pkg/utils/helper_test.go index 34a4e4d2..aee3659b 100644 --- a/pkg/utils/helper_test.go +++ b/pkg/utils/helper_test.go @@ -538,26 +538,29 @@ func TestGetSourceAcctAndArn(t *testing.T) { clusterName := "test-cluster" region := "us-west-2" clusterARN := "arn:aws:eks:us-west-2:123456789876:cluster/test-cluster" - + partition := "aws" roleARN := "arn:aws:iam::123456789876:role/test-cluster" // test correct inputs - acct, arn, err := GetSourceAcctAndArn(roleARN, region, clusterName) + acct, part, arn, err := GetSourceAcctAndArn(roleARN, region, clusterName) assert.NoError(t, err, "no error should be returned with accurate role arn") + assert.Equal(t, partition, part, "correct partition should be retrieved") assert.Equal(t, accountID, acct, "correct account ID should be retrieved") assert.Equal(t, clusterARN, arn, "correct cluster arn should be retrieved") region = "us-gov-west-1" roleARN = "arn:aws-us-gov:iam::123456789876:role/test-cluster" clusterARN = "arn:aws-us-gov:eks:us-gov-west-1:123456789876:cluster/test-cluster" - acct, arn, err = GetSourceAcctAndArn(roleARN, region, clusterName) + partition = "aws-us-gov" + acct, part, arn, err = GetSourceAcctAndArn(roleARN, region, clusterName) assert.NoError(t, err, "no error should be returned with accurate aws-us-gov partition role arn") assert.Equal(t, accountID, acct, "correct account ID should be retrieved") + assert.Equal(t, partition, part, "correct patition should be retrieved") assert.Equal(t, clusterARN, arn, "correct gov partition cluster arn should be retrieved") // test error handling roleARN = "arn:aws:iam::123456789876" - _, _, err = GetSourceAcctAndArn(roleARN, region, clusterName) + _, _, _, err = GetSourceAcctAndArn(roleARN, region, clusterName) assert.Error(t, err, "error should be returned with inaccurate role arn is given") } @@ -569,8 +572,10 @@ func TestGetSourceAcctAndArn_NoRegion(t *testing.T) { roleARN := "arn:aws:iam::123456789876:role/test-cluster" // test correct inputs - acct, arn, err := GetSourceAcctAndArn(roleARN, region, clusterName) + acct, part, arn, err := GetSourceAcctAndArn(roleARN, region, clusterName) assert.NoError(t, err, "no error should be returned with accurate role arn") assert.Equal(t, "", acct, "correct account ID should be retrieved") assert.Equal(t, "", arn, "correct cluster arn should be retrieved") + assert.Equal(t, "", part, "correct partiton should be retrieved") + }