From 13f8baad30292c8b47381a4189c8d2a57dab0262 Mon Sep 17 00:00:00 2001 From: Paul Ingles Date: Tue, 20 Oct 2020 22:04:28 +0100 Subject: [PATCH] More improvements on region handling * Removed most of the region config setup from sts.DefaultGateway into a configBuilder, added more tests around configBuilder to confirm behaviour * Changed server to request server credentials with the server assume role after configuring for region, should address #368 * Regional endpoint adds a us-iso prefix to handle airgapped regions addressing #410 * Updated version of AWS SDK to 1.35 --- go.mod | 4 +- go.sum | 10 ++ pkg/aws/sts/aws_endpoint_resolver.go | 6 +- pkg/aws/sts/aws_endpoint_resolver_test.go | 31 ++++- pkg/aws/sts/gateway.go | 20 +-- pkg/aws/sts/gateway_test.go | 37 ------ pkg/aws/sts/kiam_configuration_builder.go | 69 ++++++++++ .../sts/kiam_configuration_builder_test.go | 123 ++++++++++++++++++ pkg/server/server.go | 18 ++- 9 files changed, 251 insertions(+), 67 deletions(-) delete mode 100644 pkg/aws/sts/gateway_test.go create mode 100644 pkg/aws/sts/kiam_configuration_builder.go create mode 100644 pkg/aws/sts/kiam_configuration_builder_test.go diff --git a/go.mod b/go.mod index 1cc99c96..72d4df70 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.13 require ( github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc // indirect github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf // indirect - github.com/aws/aws-sdk-go v1.25.34 + github.com/aws/aws-sdk-go v1.35.10 github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 // indirect github.com/cenkalti/backoff v2.0.0+incompatible github.com/coreos/go-iptables v0.3.0 @@ -38,7 +38,7 @@ require ( github.com/stretchr/testify v1.4.0 // indirect github.com/uswitch/k8sc v0.0.0-20170525133932-475c8175b340 github.com/vmg/backoff v1.0.0 - golang.org/x/net v0.0.0-20190311183353-d8887717615a + golang.org/x/net v0.0.0-20200202094626-16171245cfb2 golang.org/x/sys v0.0.0-20200117145432-59e60aa80a0c // indirect golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2 // indirect google.golang.org/grpc v1.27.0 diff --git a/go.sum b/go.sum index ef1ddb49..0dd7934d 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf h1:qet1QNfXsQxTZq github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/aws/aws-sdk-go v1.25.34 h1:roL040qe1npx1ToFeXYHOGp/nOpLbcIQHKZ5UeDIyIM= github.com/aws/aws-sdk-go v1.25.34/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= +github.com/aws/aws-sdk-go v1.35.10 h1:FsJtrOS7P+Qmq1rPTGgS/+qC1Y9eGuAJHvAZpZlhmb4= +github.com/aws/aws-sdk-go v1.35.10/go.mod h1:tlPOdRjfxPBpNIwqDj61rmsnA85v9jc0Ps9+muhnW+k= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLMYoU8P317H5OQ+Via4RmuPwCS0= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/cenkalti/backoff v2.0.0+incompatible h1:5IIPUHhlnUZbcHQsQou5k1Tn58nJkeJL9U+ig5CHJbY= @@ -73,6 +75,9 @@ github.com/imdario/mergo v0.3.4 h1:mKkfHkZWD8dC7WxKx3N9WCF0Y+dLau45704YQmY6H94= github.com/imdario/mergo v0.3.4/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJh5FfA= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= +github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/json-iterator/go v0.0.0-20180315132816-ca39e5af3ece h1:3HJXp/18JmMk5sjBP3LDUBtWjczCvynxaeAF6b6kWp8= github.com/json-iterator/go v0.0.0-20180315132816-ca39e5af3ece/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= @@ -88,6 +93,7 @@ github.com/onsi/gomega v1.7.1 h1:K0jcRCwNQM3vFGh1ppMtDh/+7ApJrjldlX8fA0jDTLQ= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v0.9.0-pre1 h1:AWTOhsOI9qxeirTuA0A4By/1Es1+y9EcCGY6bBZ2fhM= @@ -121,6 +127,8 @@ golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73r golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2 h1:CCH4IOTTfewWjGOlSp+zGcjutRKlBEZQ6wTn8ozI/nI= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -186,6 +194,8 @@ gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWD gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4 h1:/eiJrUcujPVeJ3xlSWaiNi3uSVmDGBK1pDHUHAnao1I= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= k8s.io/api v0.0.0-20180521142803-feb48db456a5 h1:ZkJvJIvl22AqkIYbow7+ZkJCZ/Vf5TnLyJ1Q5UpFXEI= diff --git a/pkg/aws/sts/aws_endpoint_resolver.go b/pkg/aws/sts/aws_endpoint_resolver.go index 73c6e714..687e7b73 100644 --- a/pkg/aws/sts/aws_endpoint_resolver.go +++ b/pkg/aws/sts/aws_endpoint_resolver.go @@ -36,6 +36,10 @@ func regionalHostname(region string) (string, error) { hostname = fmt.Sprintf("%s.cn", hostname) } + if strings.HasPrefix(region, "us-iso") { + hostname = fmt.Sprintf("sts.%s.c2s.ic.gov", region) + } + if _, err := net.LookupHost(hostname); err != nil { return "", fmt.Errorf("Regional STS endpoint does not exist: %s", hostname) } @@ -43,7 +47,7 @@ func regionalHostname(region string) (string, error) { return hostname, nil } -func NewRegionalEndpointResolver(region string) (endpoints.Resolver, error) { +func newRegionalEndpointResolver(region string) (endpoints.Resolver, error) { if region == "" || strings.Contains(region,"fips") { return endpoints.DefaultResolver(), nil } diff --git a/pkg/aws/sts/aws_endpoint_resolver_test.go b/pkg/aws/sts/aws_endpoint_resolver_test.go index 4db51622..574a811f 100644 --- a/pkg/aws/sts/aws_endpoint_resolver_test.go +++ b/pkg/aws/sts/aws_endpoint_resolver_test.go @@ -19,7 +19,7 @@ import ( ) func TestUsesDefaultForOtherServices(t *testing.T) { - r, _ := NewRegionalEndpointResolver("eu-west-1") + r, _ := newRegionalEndpointResolver("eu-west-1") rd, err := r.EndpointFor(endpoints.S3ServiceID, endpoints.EuWest1RegionID) if err != nil { t.Error(err) @@ -30,7 +30,7 @@ func TestUsesDefaultForOtherServices(t *testing.T) { } func TestResolvesDefaultRegion(t *testing.T) { - resolver, _ := NewRegionalEndpointResolver("") + resolver, _ := newRegionalEndpointResolver("") resolved, err := resolver.EndpointFor(endpoints.StsServiceID, "") if err != nil { @@ -43,7 +43,7 @@ func TestResolvesDefaultRegion(t *testing.T) { } func TestResolvesUsingSpecifiedRegion(t *testing.T) { - resolver, _ := NewRegionalEndpointResolver("us-west-2") + resolver, _ := newRegionalEndpointResolver("us-west-2") resolved, err := resolver.EndpointFor(endpoints.StsServiceID, "") if err != nil { t.Error(err) @@ -55,7 +55,7 @@ func TestResolvesUsingSpecifiedRegion(t *testing.T) { } func TestResolvesEURegion(t *testing.T) { - resolver, _ := NewRegionalEndpointResolver("eu-west-1") + resolver, _ := newRegionalEndpointResolver("eu-west-1") resolved, err := resolver.EndpointFor(endpoints.StsServiceID, "") if err != nil { t.Error(err) @@ -67,7 +67,7 @@ func TestResolvesEURegion(t *testing.T) { } func TestAddsChinaPrefixForChineseRegions(t *testing.T) { - resolver, err := NewRegionalEndpointResolver("cn-north-1") + resolver, err := newRegionalEndpointResolver("cn-north-1") if err != nil { t.Error(err) } @@ -83,7 +83,7 @@ func TestAddsChinaPrefixForChineseRegions(t *testing.T) { } func TestUseDefaultForFIPS(t *testing.T) { - r, e := NewRegionalEndpointResolver("us-east-1-fips") + r, e := newRegionalEndpointResolver("us-east-1-fips") if e != nil { t.Error(e) } @@ -99,7 +99,7 @@ func TestUseDefaultForFIPS(t *testing.T) { } func TestGovGateway(t *testing.T) { - r, e := NewRegionalEndpointResolver("us-gov-east-1") + r, e := newRegionalEndpointResolver("us-gov-east-1") if e != nil { t.Error(e) } @@ -112,4 +112,21 @@ func TestGovGateway(t *testing.T) { if rd.URL != "https://sts.us-gov-east-1.amazonaws.com" { t.Error("unexpected", rd.URL) } +} + +// https://github.com/uswitch/kiam/issues/410 +func TestAirgappedRegion(t *testing.T) { + r, e := newRegionalEndpointResolver("us-iso-east-1") + if e != nil { + t.Error(e) + } + + rd, e := r.EndpointFor(endpoints.StsServiceID, "us-iso-east-1") + if e != nil { + t.Error(e) + } + + if rd.URL != "https://sts.us-iso-east-1.c2s.ic.gov" { + t.Error("unexpected", rd.URL) + } } \ No newline at end of file diff --git a/pkg/aws/sts/gateway.go b/pkg/aws/sts/gateway.go index 3f0e87d4..f6892ef8 100644 --- a/pkg/aws/sts/gateway.go +++ b/pkg/aws/sts/gateway.go @@ -18,8 +18,6 @@ import ( "time" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials/stscreds" - "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/sts" "github.com/prometheus/client_golang/prometheus" @@ -31,25 +29,9 @@ type STSGateway interface { type DefaultSTSGateway struct { session *session.Session - resolver endpoints.Resolver } -func DefaultGateway(assumeRoleArn, region string) (*DefaultSTSGateway, error) { - config := aws.NewConfig().WithCredentialsChainVerboseErrors(true) - - if assumeRoleArn != "" { - config.WithCredentials(stscreds.NewCredentials(session.Must(session.NewSession()), assumeRoleArn)) - } - - if region != "" { - resolver, err := NewRegionalEndpointResolver(region) - if err != nil { - return nil, err - } - - config.WithRegion(region).WithEndpointResolver(resolver) - } - +func DefaultGateway(config *aws.Config) (*DefaultSTSGateway, error) { session := session.Must(session.NewSession(config)) return &DefaultSTSGateway{session: session}, nil } diff --git a/pkg/aws/sts/gateway_test.go b/pkg/aws/sts/gateway_test.go deleted file mode 100644 index 247d71ea..00000000 --- a/pkg/aws/sts/gateway_test.go +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2017 uSwitch -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -package sts - -import ( - "testing" - - "github.com/aws/aws-sdk-go/service/sts" -) - -func TestConfiguresSessionWithRegion(t *testing.T) { - gateway, err := DefaultGateway("", "us-west-2") - if err != nil { - t.Error(err) - } - - config := gateway.session.ClientConfig(sts.EndpointsID) - - if config.SigningRegion != "us-west-2" { - t.Error("Unexpected region. Region was: ", config.SigningRegion) - } - - if config.Endpoint != "https://sts.us-west-2.amazonaws.com" { - t.Error("Unexpected regional endpoint. Endpoint was: ", config.Endpoint) - } -} diff --git a/pkg/aws/sts/kiam_configuration_builder.go b/pkg/aws/sts/kiam_configuration_builder.go new file mode 100644 index 00000000..46023bce --- /dev/null +++ b/pkg/aws/sts/kiam_configuration_builder.go @@ -0,0 +1,69 @@ +// Copyright 2017 uSwitch +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package sts + +import ( + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/credentials/stscreds" + "github.com/aws/aws-sdk-go/aws/session" +) + +type awsConfigCredentialsProvider interface { + NewCredentials(cfg *aws.Config, assumeRoleARN string) *credentials.Credentials +} + +type STSCredentialsProvider struct { +} + +func (s *STSCredentialsProvider) NewCredentials(cfg *aws.Config, assumeRoleARN string) *credentials.Credentials { + return stscreds.NewCredentials(session.Must(session.NewSession(cfg)), assumeRoleARN) +} + +func NewSTSCredentialsProvider() *STSCredentialsProvider { + return &STSCredentialsProvider{} +} + +type configBuilder struct { + config *aws.Config +} + +// Builds the necessary AWS config for Kiam's server +func NewServerConfigBuilder() *configBuilder { + return &configBuilder{config: aws.NewConfig().WithCredentialsChainVerboseErrors(true)} +} + +func (c *configBuilder) WithRegion(region string) (*configBuilder, error) { + resolver, err := newRegionalEndpointResolver(region) + if err != nil { + return nil, err + } + + c.config.WithRegion(region).WithEndpointResolver(resolver) + + return c, nil +} + +func (c *configBuilder) WithCredentialsFromAssumedRole(provider awsConfigCredentialsProvider, assumeRoleARN string) *configBuilder { + if assumeRoleARN == "" { + return c + } + + c.config.WithCredentials(provider.NewCredentials(c.config, assumeRoleARN)) + return c +} + +func (c *configBuilder) Config() *aws.Config { + return c.config +} \ No newline at end of file diff --git a/pkg/aws/sts/kiam_configuration_builder_test.go b/pkg/aws/sts/kiam_configuration_builder_test.go new file mode 100644 index 00000000..7d4c3402 --- /dev/null +++ b/pkg/aws/sts/kiam_configuration_builder_test.go @@ -0,0 +1,123 @@ +// Copyright 2017 uSwitch +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package sts + +import ( + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/endpoints" + "testing" +) + +func TestDefaultConfig(t *testing.T) { + b := NewServerConfigBuilder() + + if b.Config().Region != nil { + t.Error("expected nil region, was", *b.Config().Region) + } + + if !*b.Config().CredentialsChainVerboseErrors { + t.Error("expected verbose errors") + } +} + +func TestConfigWithRegion(t *testing.T) { + b, _ := NewServerConfigBuilder().WithRegion(endpoints.UsEast1RegionID) + + if *b.Config().Region != endpoints.UsEast1RegionID { + t.Error("unexpected region", *b.Config().Region) + } + + // it should also configure with our custom endpoint resolver + _, ok := b.Config().EndpointResolver.(*regionalEndpointResolver) + if !ok { + t.Errorf("expected endpoint resolver to be castable to *regionalEndpointResolver, was %T", b.Config().EndpointResolver) + } +} + +func TestWithCredentials(t *testing.T) { + const accessKeyID = "id" + creds := credentials.NewStaticCredentials(accessKeyID, "secret", "token") + p := newStubCredentialsProvider(creds) + + b := NewServerConfigBuilder() + b.WithCredentialsFromAssumedRole(p, "my test role") + + if b.Config().Credentials != creds { + t.Errorf("expected same credentials, was %v", b.Config().Credentials) + } + + v, _ := b.Config().Credentials.Get() + if v.AccessKeyID != accessKeyID { + t.Error("unexpected access key", v.AccessKeyID) + } +} + +func TestWithEmptyAssumeRole(t *testing.T) { + creds := credentials.NewStaticCredentials("foo", "secret", "token") + p := newStubCredentialsProvider(creds) + + b := NewServerConfigBuilder() + b.WithCredentialsFromAssumedRole(p, "") + + if p.calls != 0 { + t.Error("shouldn't have called provider with empty role") + } +} + +func TestConfiguresWithCredentialsFromProvider(t *testing.T) { + const accessKeyID = "AccessKeyID-example" + creds := credentials.NewStaticCredentials(accessKeyID, "secret", "token") + stubProvider := newStubCredentialsProvider(creds) + + builder := NewServerConfigBuilder() + builder.WithCredentialsFromAssumedRole(stubProvider, "my test role") + + c, _ := builder.Config().Credentials.Get() + if c.AccessKeyID != accessKeyID { + t.Errorf("expected id as access key, was %s", c.AccessKeyID) + } +} + +func TestProvidesConfigurationToCredentialsProvider(t *testing.T) { + creds := credentials.NewStaticCredentials("foo", "secret", "token") + stubProvider := newStubCredentialsProvider(creds) + + builder := NewServerConfigBuilder() + builder.WithCredentialsFromAssumedRole(stubProvider, "my test role") + + if stubProvider.requestedConfig != builder.Config() { + t.Error("expected builder config to be passed to credentials provider") + } +} + + +func newStubCredentialsProvider(creds *credentials.Credentials) *stubCredentialsProvider { + return &stubCredentialsProvider{ + credentials: creds, + requestedConfig: nil, + } +} + +type stubCredentialsProvider struct { + credentials *credentials.Credentials + requestedConfig *aws.Config + calls int +} + +func (s *stubCredentialsProvider) NewCredentials(cfg *aws.Config, assumeRoleARN string) *credentials.Credentials { + s.requestedConfig = cfg + s.calls += 1 + return s.credentials +} diff --git a/pkg/server/server.go b/pkg/server/server.go index c1536440..f8a4aada 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -168,16 +168,32 @@ func newRoleARNResolver(config *Config) (sts.ARNResolver, error) { return sts.DefaultResolver(config.RoleBaseARN), nil } +func newSTSGateway(config *Config) (sts.STSGateway, error) { + cfg, err := sts.NewServerConfigBuilder().WithRegion(config.Region) + if err != nil { + return nil, err + } + cfg.WithCredentialsFromAssumedRole(sts.NewSTSCredentialsProvider(), config.AssumeRoleArn) + stsGateway, err := sts.DefaultGateway(cfg.Config()) + if err != nil { + return nil, err + } + + return stsGateway, nil +} + // NewServer constructs a new server. func NewServer(config *Config) (_ *KiamServer, err error) { arnResolver, err := newRoleARNResolver(config) if err != nil { return nil, err } - stsGateway, err := sts.DefaultGateway(arnResolver.Resolve(config.AssumeRoleArn), config.Region) + + stsGateway, err := newSTSGateway(config) if err != nil { return nil, err } + credentialsCache := sts.DefaultCache( stsGateway, config.SessionName,