Skip to content

Commit

Permalink
Update aws-sdk-go and change way to get regional sts endpoint (#466)
Browse files Browse the repository at this point in the history
  • Loading branch information
jaydeokar authored Sep 12, 2024
1 parent 712887d commit 19ed9ef
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 22 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ go 1.22.5

require (
github.com/aws/amazon-vpc-cni-k8s v1.18.1
github.com/aws/aws-sdk-go v1.51.32
github.com/aws/aws-sdk-go v1.55.5
github.com/go-logr/logr v1.4.2
github.com/go-logr/zapr v1.3.0
github.com/golang/mock v1.6.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPd
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
github.com/aws/amazon-vpc-cni-k8s v1.18.1 h1:u/OeBgnUUX6f3PCEOpA4dbG0+iZ71CnY6tEljjrl3iw=
github.com/aws/amazon-vpc-cni-k8s v1.18.1/go.mod h1:m/J5GsxF0Th2iQTOE3ww4W9LFvwdC0tGyA9dIL4h6iQ=
github.com/aws/aws-sdk-go v1.51.32 h1:A6mPui7QP4mwmovyzgtdedbRbNur1Iu0/El7hBWNHms=
github.com/aws/aws-sdk-go v1.51.32/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk=
github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU=
github.com/aws/aws-sdk-go v1.55.5/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
Expand Down
50 changes: 41 additions & 9 deletions pkg/aws/ec2/api/wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -510,21 +510,21 @@ func (e *ec2Wrapper) getClientUsingAssumedRole(instanceRegion, roleARN, clusterN
}
e.log.Info("created rate limited http client", "qps", qps, "burst", burst)

// Get the regional sts end point
regionalSTSEndpoint, err := endpoints.DefaultResolver().
EndpointFor("sts", aws.StringValue(userStsSession.Config.Region), endpoints.STSRegionalEndpointOption)
if err != nil {
return nil, fmt.Errorf("failed to get the regional sts endoint for region %s: %v",
*userStsSession.Config.Region, err)
}

// GetPartition ID, SourceAccount and SourceARN
roleARN = strings.Trim(roleARN, "\"")

sourceAcct, sourceArn, err := utils.GetSourceAcctAndArn(roleARN, region, clusterName)
sourceAcct, partitionID, sourceArn, err := utils.GetSourceAcctAndArn(roleARN, region, clusterName)
if err != nil {
return nil, err
}

// Get the regional sts end point
regionalSTSEndpoint, err := e.getRegionalStsEndpoint(partitionID, region)
if err != nil {
return nil, fmt.Errorf("failed to get the regional sts endpoint for region %s: %v %v",
*userStsSession.Config.Region, err, partitionID)
}

regionalProvider := &stscreds.AssumeRoleProvider{
Client: e.createSTSClient(userStsSession, client, regionalSTSEndpoint, sourceAcct, sourceArn),
RoleARN: roleARN,
Expand Down Expand Up @@ -892,3 +892,35 @@ func (e *ec2Wrapper) DisassociateTrunkInterface(input *ec2.DisassociateTrunkInte
}
return err
}

func (e *ec2Wrapper) getRegionalStsEndpoint(partitionID, region string) (endpoints.ResolvedEndpoint, error) {
var partition *endpoints.Partition
var stsServiceID = "sts"
for _, p := range endpoints.DefaultPartitions() {
if partitionID == p.ID() {
partition = &p
break
}
}
if partition == nil {
return endpoints.ResolvedEndpoint{}, fmt.Errorf("partition %s not valid", partitionID)
}

stsSvc, ok := partition.Services()[stsServiceID]
if !ok {
e.log.Info("STS service not found in partition, generating default endpoint.", "Partition:", partitionID)
// Add the host of the current instances region if the service doesn't already exists in the partition
// so we don't fail if the service is not present in the go sdk but matches the instances region.
res, err := partition.EndpointFor(stsServiceID, region, endpoints.STSRegionalEndpointOption, endpoints.ResolveUnknownServiceOption)
if err != nil {
return endpoints.ResolvedEndpoint{}, fmt.Errorf("error resolving endpoint for %s in partition %s. err: %v", region, partition.ID(), err)
}
return res, nil
}

res, err := stsSvc.ResolveEndpoint(region, endpoints.STSRegionalEndpointOption)
if err != nil {
return endpoints.ResolvedEndpoint{}, fmt.Errorf("error resolving endpoint for %s in partition %s. err: %v", region, partition.ID(), err)
}
return res, nil
}
65 changes: 65 additions & 0 deletions pkg/aws/ec2/api/wrapper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package api

import (
"testing"
)

func getMockEC2Wrapper() ec2Wrapper {
return ec2Wrapper{}
}
func Test_getRegionalStsEndpoint(t *testing.T) {

ec2Wapper := getMockEC2Wrapper()

type args struct {
partitionID string
region string
}

tests := []struct {
name string
args args
want string
wantErr bool
}{
{
name: "service doesn't exist in partition",
args: args{
partitionID: "aws-iso-f",
region: "testregions",
},
want: "https://sts.testregions.csp.hci.ic.gov",
wantErr: false,
},
{
name: "region doesn't exist in partition",
args: args{
partitionID: "aws",
region: "us-test-2",
},
want: "https://sts.us-test-2.amazonaws.com",
wantErr: false,
},
{
name: "region and service exist in partition",
args: args{
partitionID: "aws",
region: "us-west-2",
},
want: "https://sts.us-west-2.amazonaws.com",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ec2Wapper.getRegionalStsEndpoint(tt.args.partitionID, tt.args.region)
if (err != nil) != tt.wantErr {
t.Errorf("getRegionalStsEndpoint() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got.URL != tt.want {
t.Errorf("getRegionalStsEndpoint() = %v, want %v", got, tt.want)
}
})
}
}
10 changes: 5 additions & 5 deletions pkg/utils/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,22 +213,22 @@ func IsNitroInstance(instanceType string) (bool, error) {
}

// GetSourceAcctAndArn constructs source acct and arn and return them for use
func GetSourceAcctAndArn(roleARN, region, clusterName string) (string, string, error) {
func GetSourceAcctAndArn(roleARN, region, clusterName string) (string, string, string, error) {
// ARN format (https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html)
// arn:partition:service:region:account-id:resource-type/resource-id
// IAM format, region is always blank
// arn:aws:iam::account:role/role-name-with-path
if !arn.IsARN(roleARN) {
return "", "", fmt.Errorf("incorrect ARN format for role %s", roleARN)
return "", "", "", fmt.Errorf("incorrect ARN format for role %s", roleARN)
} else if region == "" {
return "", "", nil
return "", "", "", nil
}

parsedArn, err := arn.Parse(roleARN)
if err != nil {
return "", "", err
return "", "", "", err
}

sourceArn := fmt.Sprintf("arn:%s:eks:%s:%s:cluster/%s", parsedArn.Partition, region, parsedArn.AccountID, clusterName)
return parsedArn.AccountID, sourceArn, nil
return parsedArn.AccountID, parsedArn.Partition, sourceArn, nil
}
15 changes: 10 additions & 5 deletions pkg/utils/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -538,26 +538,29 @@ func TestGetSourceAcctAndArn(t *testing.T) {
clusterName := "test-cluster"
region := "us-west-2"
clusterARN := "arn:aws:eks:us-west-2:123456789876:cluster/test-cluster"

partition := "aws"
roleARN := "arn:aws:iam::123456789876:role/test-cluster"

// test correct inputs
acct, arn, err := GetSourceAcctAndArn(roleARN, region, clusterName)
acct, part, arn, err := GetSourceAcctAndArn(roleARN, region, clusterName)
assert.NoError(t, err, "no error should be returned with accurate role arn")
assert.Equal(t, partition, part, "correct partition should be retrieved")
assert.Equal(t, accountID, acct, "correct account ID should be retrieved")
assert.Equal(t, clusterARN, arn, "correct cluster arn should be retrieved")

region = "us-gov-west-1"
roleARN = "arn:aws-us-gov:iam::123456789876:role/test-cluster"
clusterARN = "arn:aws-us-gov:eks:us-gov-west-1:123456789876:cluster/test-cluster"
acct, arn, err = GetSourceAcctAndArn(roleARN, region, clusterName)
partition = "aws-us-gov"
acct, part, arn, err = GetSourceAcctAndArn(roleARN, region, clusterName)
assert.NoError(t, err, "no error should be returned with accurate aws-us-gov partition role arn")
assert.Equal(t, accountID, acct, "correct account ID should be retrieved")
assert.Equal(t, partition, part, "correct patition should be retrieved")
assert.Equal(t, clusterARN, arn, "correct gov partition cluster arn should be retrieved")

// test error handling
roleARN = "arn:aws:iam::123456789876"
_, _, err = GetSourceAcctAndArn(roleARN, region, clusterName)
_, _, _, err = GetSourceAcctAndArn(roleARN, region, clusterName)
assert.Error(t, err, "error should be returned with inaccurate role arn is given")
}

Expand All @@ -569,8 +572,10 @@ func TestGetSourceAcctAndArn_NoRegion(t *testing.T) {
roleARN := "arn:aws:iam::123456789876:role/test-cluster"

// test correct inputs
acct, arn, err := GetSourceAcctAndArn(roleARN, region, clusterName)
acct, part, arn, err := GetSourceAcctAndArn(roleARN, region, clusterName)
assert.NoError(t, err, "no error should be returned with accurate role arn")
assert.Equal(t, "", acct, "correct account ID should be retrieved")
assert.Equal(t, "", arn, "correct cluster arn should be retrieved")
assert.Equal(t, "", part, "correct partiton should be retrieved")

}

0 comments on commit 19ed9ef

Please sign in to comment.