diff --git a/controller/controller.go b/controller/controller.go index 183d9ef..273920a 100644 --- a/controller/controller.go +++ b/controller/controller.go @@ -72,7 +72,7 @@ func (c *EgressController) Run(ctx context.Context) { select { case <-time.After(c.interval): - err := c.provider.Ensure(c.configsCache) + err := c.provider.Ensure(ctx, c.configsCache) if err != nil { log.Errorf("Failed to ensure configuration: %v", err) continue @@ -87,7 +87,7 @@ func (c *EgressController) Run(ctx context.Context) { c.configsCache[config.Resource] = config.IPAddresses } - err := c.provider.Ensure(c.configsCache) + err := c.provider.Ensure(ctx, c.configsCache) if err != nil { log.Errorf("Failed to ensure configuration: %v", err) continue diff --git a/go.mod b/go.mod index dd5ff66..7045dc0 100644 --- a/go.mod +++ b/go.mod @@ -7,10 +7,13 @@ toolchain go1.22.4 require ( github.com/alecthomas/kingpin/v2 v2.4.0 github.com/apparentlymart/go-cidr v1.1.0 - github.com/aws/aws-sdk-go v1.54.7 + github.com/aws/aws-sdk-go-v2 v1.30.0 + github.com/aws/aws-sdk-go-v2/config v1.27.21 + github.com/aws/aws-sdk-go-v2/service/cloudformation v1.52.1 + github.com/aws/aws-sdk-go-v2/service/ec2 v1.165.1 + github.com/aws/smithy-go v1.20.2 github.com/crewjam/go-cloudformation v0.0.0-20180605015303-38e5b663797c github.com/google/uuid v1.6.0 - github.com/linki/instrumented_http v0.3.0 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.19.1 github.com/sirupsen/logrus v1.9.3 @@ -22,6 +25,16 @@ require ( require ( github.com/alecthomas/units v0.0.0-20231202071711-9a357b53e9c9 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.21 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.8 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.12 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.12 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.14 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.21.1 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.25.1 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.29.1 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect diff --git a/go.sum b/go.sum index 6e7a307..a827b27 100644 --- a/go.sum +++ b/go.sum @@ -4,8 +4,36 @@ github.com/alecthomas/units v0.0.0-20231202071711-9a357b53e9c9 h1:ez/4by2iGztzR4 github.com/alecthomas/units v0.0.0-20231202071711-9a357b53e9c9/go.mod h1:OMCwj8VM1Kc9e19TLln2VL61YJF0x1XFtfdL4JdbSyE= github.com/apparentlymart/go-cidr v1.1.0 h1:2mAhrMoF+nhXqxTzSZMUzDHkLjmIHC+Zzn4tdgBZjnU= github.com/apparentlymart/go-cidr v1.1.0/go.mod h1:EBcsNrHc3zQeuaeCeCtQruQm+n9/YjEn/vI25Lg7Gwc= -github.com/aws/aws-sdk-go v1.54.7 h1:k1wJ+NMOsXgq/Lsa0y1mS0DFoDeHFPcz2OjCq5H5Mjg= -github.com/aws/aws-sdk-go v1.54.7/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= +github.com/aws/aws-sdk-go-v2 v1.30.0 h1:6qAwtzlfcTtcL8NHtbDQAqgM5s6NDipQTkPxyH/6kAA= +github.com/aws/aws-sdk-go-v2 v1.30.0/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= +github.com/aws/aws-sdk-go-v2/config v1.27.21 h1:yPX3pjGCe2hJsetlmGNB4Mngu7UPmvWPzzWCv1+boeM= +github.com/aws/aws-sdk-go-v2/config v1.27.21/go.mod h1:4XtlEU6DzNai8RMbjSF5MgGZtYvrhBP/aKZcRtZAVdM= +github.com/aws/aws-sdk-go-v2/credentials v1.17.21 h1:pjAqgzfgFhTv5grc7xPHtXCAaMapzmwA7aU+c/SZQGw= +github.com/aws/aws-sdk-go-v2/credentials v1.17.21/go.mod h1:nhK6PtBlfHTUDVmBLr1dg+WHCOCK+1Fu/WQyVHPsgNQ= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.8 h1:FR+oWPFb/8qMVYMWN98bUZAGqPvLHiyqg1wqQGfUAXY= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.8/go.mod h1:EgSKcHiuuakEIxJcKGzVNWh5srVAQ3jKaSrBGRYvM48= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.12 h1:SJ04WXGTwnHlWIODtC5kJzKbeuHt+OUNOgKg7nfnUGw= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.12/go.mod h1:FkpvXhA92gb3GE9LD6Og0pHHycTxW7xGpnEh5E7Opwo= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.12 h1:hb5KgeYfObi5MHkSSZMEudnIvX30iB+E21evI4r6BnQ= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.12/go.mod h1:CroKe/eWJdyfy9Vx4rljP5wTUjNJfb+fPz1uMYUhEGM= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY= +github.com/aws/aws-sdk-go-v2/service/cloudformation v1.52.1 h1:Ts+mCjOtt8o2k2vnWnX/0sE0eSmEVWBvfJkNrNMQlAo= +github.com/aws/aws-sdk-go-v2/service/cloudformation v1.52.1/go.mod h1:IrWhabzdTEc651GAq7rgst/SYcEqqcD7Avr82m28AAU= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.165.1 h1:LkSnU1c9JKJyXYcwpWgQGuwctwv3pDenMUgH2CmLd1A= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.165.1/go.mod h1:Wv7N3iFOKVsZNIaw9MOBUmwCkX6VMmQQRFhMrHtNGno= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 h1:Ji0DY1xUsUr3I8cHps0G+XM3WWU16lP6yG8qu1GAZAs= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2/go.mod h1:5CsjAbs3NlGQyZNFACh+zztPDI7fU6eW9QsxjfnuBKg= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.14 h1:zSDPny/pVnkqABXYRicYuPf9z2bTqfH13HT3v6UheIk= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.14/go.mod h1:3TTcI5JSzda1nw/pkVC9dhgLre0SNBFj2lYS4GctXKI= +github.com/aws/aws-sdk-go-v2/service/sso v1.21.1 h1:sd0BsnAvLH8gsp2e3cbaIr+9D7T1xugueQ7V/zUAsS4= +github.com/aws/aws-sdk-go-v2/service/sso v1.21.1/go.mod h1:lcQG/MmxydijbeTOp04hIuJwXGWPZGI3bwdFDGRTv14= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.25.1 h1:1uEFNNskK/I1KoZ9Q8wJxMz5V9jyBlsiaNrM7vA3YUQ= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.25.1/go.mod h1:z0P8K+cBIsFXUr5rzo/psUeJ20XjPN0+Nn8067Nd+E4= +github.com/aws/aws-sdk-go-v2/service/sts v1.29.1 h1:myX5CxqXE0QMZNja6FA1/FSE3Vu1rVmeUmpJMMzeZg0= +github.com/aws/aws-sdk-go-v2/service/sts v1.29.1/go.mod h1:N2mQiucsO0VwK9CYuS4/c2n6Smeh1v47Rz3dWCPFLdE= +github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q= +github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= @@ -69,8 +97,6 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= -github.com/linki/instrumented_http v0.3.0 h1:dsN92+mXpfZtjJraartcQ99jnuw7fqsnPDjr85ma2dA= -github.com/linki/instrumented_http v0.3.0/go.mod h1:pjYbItoegfuVi2GUOMhEqzvm/SJKuEL3H0tc8QRLRFk= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= diff --git a/main.go b/main.go index 237bc61..631d810 100644 --- a/main.go +++ b/main.go @@ -10,6 +10,7 @@ import ( "time" kingpin "github.com/alecthomas/kingpin/v2" + "github.com/aws/aws-sdk-go-v2/config" "github.com/prometheus/client_golang/prometheus/promhttp" log "github.com/sirupsen/logrus" "github.com/szuecs/kube-static-egress-controller/controller" @@ -77,16 +78,20 @@ func NewConfig() *Config { } } -func newProvider(clusterID, controllerID string, dry bool, name, vpcID string, clusterIDTagPrefix string, natCidrBlocks, availabilityZones []string, stackTerminationProtection bool, additionalStackTags map[string]string) provider.Provider { +func newProvider(clusterID, controllerID string, dry bool, name, vpcID string, clusterIDTagPrefix string, natCidrBlocks, availabilityZones []string, stackTerminationProtection bool, additionalStackTags map[string]string) (provider.Provider, error) { switch name { case aws.ProviderName: - return aws.NewAWSProvider(clusterID, controllerID, dry, vpcID, clusterIDTagPrefix, natCidrBlocks, availabilityZones, stackTerminationProtection, additionalStackTags) + cfg, err := config.LoadDefaultConfig(context.TODO()) + if err != nil { + return nil, err + } + return aws.NewAWSProvider(cfg, clusterID, controllerID, dry, vpcID, clusterIDTagPrefix, natCidrBlocks, availabilityZones, stackTerminationProtection, additionalStackTags) case noop.ProviderName: - return noop.NewNoopProvider() + return noop.NewNoopProvider(), nil default: - log.Fatalf("Unkown provider: %s", name) + return nil, fmt.Errorf("Unkown provider: %s", name) } - return nil + return nil, nil } func allLogLevelsAsStrings() []string { @@ -157,7 +162,10 @@ func main() { log.SetLevel(ll) log.Debugf("config: %+v", cfg) - p := newProvider(cfg.ClusterID, cfg.ControllerID, cfg.DryRun, cfg.Provider, cfg.VPCID, cfg.ClusterIDTagPrefix, cfg.NatCidrBlocks, cfg.AvailabilityZones, cfg.StackTerminationProtection, cfg.AdditionalStackTags) + p, err := newProvider(cfg.ClusterID, cfg.ControllerID, cfg.DryRun, cfg.Provider, cfg.VPCID, cfg.ClusterIDTagPrefix, cfg.NatCidrBlocks, cfg.AvailabilityZones, cfg.StackTerminationProtection, cfg.AdditionalStackTags) + if err != nil { + log.Fatalf("Failed to create provider: %v", err) + } configsChan := make(chan provider.EgressConfig) cmWatcher, err := kube.NewConfigMapWatcher(newKubeClient(), cfg.Namespace, "egress=static", configsChan) diff --git a/provider/aws/aws.go b/provider/aws/aws.go index aec01f8..774424e 100644 --- a/provider/aws/aws.go +++ b/provider/aws/aws.go @@ -9,16 +9,14 @@ import ( "strings" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/client" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/cloudformation" - "github.com/aws/aws-sdk-go/service/cloudformation/cloudformationiface" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ec2/ec2iface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/transport/http" + "github.com/aws/aws-sdk-go-v2/service/cloudformation" + cftypes "github.com/aws/aws-sdk-go-v2/service/cloudformation/types" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/aws/smithy-go" cft "github.com/crewjam/go-cloudformation" - "github.com/linki/instrumented_http" "github.com/pkg/errors" log "github.com/sirupsen/logrus" "github.com/szuecs/kube-static-egress-controller/provider" @@ -39,12 +37,12 @@ const ( ) var ( - errCreateFailed = fmt.Errorf("wait for stack failed with %s", cloudformation.StackStatusCreateFailed) - errRollbackComplete = fmt.Errorf("wait for stack failed with %s", cloudformation.StackStatusRollbackComplete) - errUpdateRollbackComplete = fmt.Errorf("wait for stack failed with %s", cloudformation.StackStatusUpdateRollbackComplete) - errRollbackFailed = fmt.Errorf("wait for stack failed with %s", cloudformation.StackStatusRollbackFailed) - errUpdateRollbackFailed = fmt.Errorf("wait for stack failed with %s", cloudformation.StackStatusUpdateRollbackFailed) - errDeleteFailed = fmt.Errorf("wait for stack failed with %s", cloudformation.StackStatusDeleteFailed) + errCreateFailed = fmt.Errorf("wait for stack failed with %s", cftypes.StackStatusCreateFailed) + errRollbackComplete = fmt.Errorf("wait for stack failed with %s", cftypes.StackStatusRollbackComplete) + errUpdateRollbackComplete = fmt.Errorf("wait for stack failed with %s", cftypes.StackStatusUpdateRollbackComplete) + errRollbackFailed = fmt.Errorf("wait for stack failed with %s", cftypes.StackStatusRollbackFailed) + errUpdateRollbackFailed = fmt.Errorf("wait for stack failed with %s", cftypes.StackStatusUpdateRollbackFailed) + errDeleteFailed = fmt.Errorf("wait for stack failed with %s", cftypes.StackStatusDeleteFailed) errTimeoutExceeded = fmt.Errorf("wait for stack timeout exceeded") ) @@ -56,8 +54,8 @@ type AWSProvider struct { vpcID string natCidrBlocks []string availabilityZones []string - cloudformation cloudformationiface.CloudFormationAPI - ec2 ec2iface.EC2API + cloudformation cloudformationAPI + ec2 ec2API stackTerminationProtection bool additionalStackTags map[string]string logger *log.Entry @@ -71,12 +69,11 @@ type stackSpec struct { timeoutInMinutes uint template string stackTerminationProtection bool - tags []*cloudformation.Tag + tags []cftypes.Tag } -func NewAWSProvider(clusterID, controllerID string, dry bool, vpcID string, clusterIDTagPrefix string, natCidrBlocks, availabilityZones []string, stackTerminationProtection bool, additionalStackTags map[string]string) *AWSProvider { +func NewAWSProvider(cfg aws.Config, clusterID, controllerID string, dry bool, vpcID string, clusterIDTagPrefix string, natCidrBlocks, availabilityZones []string, stackTerminationProtection bool, additionalStackTags map[string]string) (*AWSProvider, error) { // TODO: find vpcID at startup - p := defaultConfigProvider() return &AWSProvider{ clusterID: clusterID, clusterIDTagPrefix: clusterIDTagPrefix, @@ -85,38 +82,38 @@ func NewAWSProvider(clusterID, controllerID string, dry bool, vpcID string, clus vpcID: vpcID, natCidrBlocks: natCidrBlocks, availabilityZones: availabilityZones, - cloudformation: cloudformation.New(p), - ec2: ec2.New(p), + cloudformation: cloudformation.NewFromConfig(cfg), + ec2: ec2.NewFromConfig(cfg), stackTerminationProtection: stackTerminationProtection, additionalStackTags: additionalStackTags, logger: log.WithFields(log.Fields{"provider": ProviderName}), - } + }, nil } func (p AWSProvider) String() string { return ProviderName } -func (p *AWSProvider) Ensure(configs map[provider.Resource]map[string]*net.IPNet) error { - stack, err := p.getEgressStack() +func (p *AWSProvider) Ensure(ctx context.Context, configs map[provider.Resource]map[string]*net.IPNet) error { + stack, err := p.getEgressStack(ctx) if err != nil { return err } // don't do anything if the stack doesn't exist and the config is empty - if len(configs) == 0 && stack == nil { + if len(configs) == 0 && stack.StackName == nil { return nil } - spec, err := p.generateStackSpec(configs) + spec, err := p.generateStackSpec(ctx, configs) if err != nil { return errors.Wrap(err, "failed to generate stack spec") } // create new stack if it doesn't already exists - if stack == nil { + if stack.StackName == nil { p.logger.Infof("Creating CF stack with config: %v", configs) - err := p.createCFStack(spec) + err := p.createCFStack(ctx, spec) if err != nil { return errors.Wrap(err, "failed to create CF stack") } @@ -124,10 +121,10 @@ func (p *AWSProvider) Ensure(configs map[provider.Resource]map[string]*net.IPNet return nil } - spec.name = aws.StringValue(stack.StackName) + spec.name = aws.ToString(stack.StackName) if len(configs) == 0 { p.logger.Info("Deleting CF stack. No egress configs") - err := p.deleteCFStack(spec.name) + err := p.deleteCFStack(ctx, spec.name) if err != nil { return err } @@ -136,7 +133,7 @@ func (p *AWSProvider) Ensure(configs map[provider.Resource]map[string]*net.IPNet } // get stack template body - templateBody, err := p.getStackTemplateBody(stack) + templateBody, err := p.getStackTemplateBody(ctx, stack) if err != nil { return err } @@ -151,7 +148,7 @@ func (p *AWSProvider) Ensure(configs map[provider.Resource]map[string]*net.IPNet // update stack with new config p.logger.Infof("Updating CF stack with config: %v", configs) - err = p.updateCFStack(spec) + err = p.updateCFStack(ctx, spec) if err != nil { return errors.Wrap(err, "failed to update CF stack") } @@ -202,17 +199,17 @@ func getCIDRsFromTemplate(template string) map[string]struct{} { return cidrs } -func findTagByKey(tags []*ec2.Tag, key string) string { +func findTagByKey(tags []ec2types.Tag, key string) string { for _, t := range tags { - if aws.StringValue(t.Key) == key { - return aws.StringValue(t.Value) + if aws.ToString(t.Key) == key { + return aws.ToString(t.Value) } } return "" } -func (p *AWSProvider) generateStackSpec(configs map[provider.Resource]map[string]*net.IPNet) (*stackSpec, error) { +func (p *AWSProvider) generateStackSpec(ctx context.Context, configs map[provider.Resource]map[string]*net.IPNet) (*stackSpec, error) { spec := &stackSpec{ name: normalizeStackName(p.clusterID), timeoutInMinutes: 10, @@ -225,14 +222,14 @@ func (p *AWSProvider) generateStackSpec(configs map[provider.Resource]map[string } spec.tags = tagMapToCloudformationTags(mergeTags(p.additionalStackTags, tags)) - vpcID, err := p.findVPC() + vpcID, err := p.findVPC(ctx) if err != nil { return nil, err } spec.vpcID = vpcID // get assigned internet gateway - igw, err := p.getInternetGatewayId(spec.vpcID) + igw, err := p.getInternetGatewayId(ctx, spec.vpcID) p.logger.Debugf("%s: igw(%d)", p, len(igw)) if err != nil { return nil, err @@ -243,11 +240,11 @@ func (p *AWSProvider) generateStackSpec(configs map[provider.Resource]map[string } // get first internet gateway ID - igwID := aws.StringValue(igw[0].InternetGatewayId) + igwID := aws.ToString(igw[0].InternetGatewayId) spec.internetGatewayID = igwID // get route tables - rt, err := p.getRouteTables(spec.vpcID) + rt, err := p.getRouteTables(ctx, spec.vpcID) p.logger.Debugf("%s: rt(%d)", p, len(rt)) if err != nil { return nil, err @@ -316,7 +313,7 @@ func (p *AWSProvider) generateStackSpec(configs map[provider.Resource]map[string paramName := fmt.Sprintf("AZ%dRouteTableIDParameter", i+1) paramOrder = append(paramOrder, paramName) tableZoneIndexes[paramName] = zindex - tableID[paramName] = aws.StringValue(table.RouteTableId) + tableID[paramName] = aws.ToString(table.RouteTableId) } spec.template = p.generateTemplate(configs, paramOrder, tableZoneIndexes) @@ -324,10 +321,10 @@ func (p *AWSProvider) generateStackSpec(configs map[provider.Resource]map[string return spec, nil } -func routeTableZone(rt *ec2.RouteTable) (string, bool) { +func routeTableZone(rt ec2types.RouteTable) (string, bool) { for _, tag := range rt.Tags { - if tagDefaultAZKeyRouteTableID == aws.StringValue(tag.Key) { - return aws.StringValue(tag.Value), true + if tagDefaultAZKeyRouteTableID == aws.ToString(tag.Key) { + return aws.ToString(tag.Value), true } } @@ -344,25 +341,25 @@ func zoneIndex(zones []string, zone string) (int, bool) { return 0, false } -func (p *AWSProvider) findVPC() (string, error) { +func (p *AWSProvider) findVPC(ctx context.Context) (string, error) { // provided by the user if p.vpcID != "" { return p.vpcID, nil } - vpcs, err := p.getVpcID() + vpcs, err := p.getVpcID(ctx) p.logger.Debugf("%s: vpcs(%d)", p, len(vpcs)) if err != nil { return "", err } if len(vpcs) == 1 { - return aws.StringValue(vpcs[0].VpcId), nil + return aws.ToString(vpcs[0].VpcId), nil } for _, vpc := range vpcs { - if aws.BoolValue(vpc.IsDefault) { - return aws.StringValue(vpc.VpcId), nil + if aws.ToBool(vpc.IsDefault) { + return aws.ToString(vpc.VpcId), nil } } @@ -464,16 +461,19 @@ func (p *AWSProvider) generateTemplate( } func isDoesNotExistsErr(err error) bool { - if awsErr, ok := err.(awserr.Error); ok { - if awsErr.Code() == "ValidationError" && strings.Contains(awsErr.Message(), "does not exist") { - // we wanted to delete a stack and it does not exist (or was removed while we were waiting, we can hide the error) - return true + if smithyErr, ok := err.(*smithy.OperationError); ok { + if respErr, ok := smithyErr.Err.(*http.ResponseError); ok { + if apiErr, ok := respErr.Err.(*smithy.GenericAPIError); ok { + if apiErr.Code == "ValidationError" && strings.Contains(apiErr.Message, "does not exist") { + return true + } + } } } return false } -func (p *AWSProvider) deleteCFStack(stackName string) error { +func (p *AWSProvider) deleteCFStack(ctx context.Context, stackName string) error { if p.dry { p.logger.Debugf("%s: Stack to delete: %s", p, stackName) return nil @@ -486,14 +486,14 @@ func (p *AWSProvider) deleteCFStack(stackName string) error { EnableTerminationProtection: aws.Bool(false), } - _, err := p.cloudformation.UpdateTerminationProtection(termParams) + _, err := p.cloudformation.UpdateTerminationProtection(ctx, termParams) if err != nil { return err } } params := &cloudformation.DeleteStackInput{StackName: aws.String(stackName)} - _, err := p.cloudformation.DeleteStack(params) + _, err := p.cloudformation.DeleteStack(ctx, params) if err != nil { if isDoesNotExistsErr(err) { return nil @@ -501,7 +501,7 @@ func (p *AWSProvider) deleteCFStack(stackName string) error { return err } - ctx, cancel := context.WithTimeout(context.Background(), maxStackWaitTimeout) + ctx, cancel := context.WithTimeout(ctx, maxStackWaitTimeout) defer cancel() err = p.waitForStack(ctx, stackStatusCheckInterval, stackName) @@ -514,11 +514,11 @@ func (p *AWSProvider) deleteCFStack(stackName string) error { return nil } -func (p *AWSProvider) updateCFStack(spec *stackSpec) error { +func (p *AWSProvider) updateCFStack(ctx context.Context, spec *stackSpec) error { params := &cloudformation.UpdateStackInput{ StackName: aws.String(spec.name), Parameters: append( - []*cloudformation.Parameter{ + []cftypes.Parameter{ cfParam(parameterVPCIDParameter, spec.vpcID), cfParam(parameterInternetGatewayIDParameter, spec.internetGatewayID), }, @@ -536,58 +536,53 @@ func (p *AWSProvider) updateCFStack(spec *stackSpec) error { EnableTerminationProtection: aws.Bool(spec.stackTerminationProtection), } - _, err := p.cloudformation.UpdateTerminationProtection(termParams) + _, err := p.cloudformation.UpdateTerminationProtection(ctx, termParams) if err != nil { return err } } - _, err := p.cloudformation.UpdateStack(params) + _, err := p.cloudformation.UpdateStack(ctx, params) if err != nil { - if awsErr, ok := err.(awserr.Error); ok { - if awsErr.Code() == "AlreadyExistsException" { - err = provider.NewAlreadyExistsError(fmt.Sprintf("%s AlreadyExists", spec.name)) - } + if isDoesNotExistsErr(err) { + return provider.NewDoesNotExistError(fmt.Sprintf("Stack '%s' does not exist", spec.name)) } return err } - ctx, cancel := context.WithTimeout(context.Background(), maxStackWaitTimeout) + ctx, cancel := context.WithTimeout(ctx, maxStackWaitTimeout) defer cancel() return p.waitForStack(ctx, stackStatusCheckInterval, spec.name) } - p.logger.Debugf("%s: DRY: Stack to update: %s", p, params) - p.logger.Debugln(aws.StringValue(params.TemplateBody)) + p.logger.Debugf("%s: DRY: Stack to update: %v", p, params) + p.logger.Debugln(aws.ToString(params.TemplateBody)) return nil } -func (p *AWSProvider) createCFStack(spec *stackSpec) error { +func (p *AWSProvider) createCFStack(ctx context.Context, spec *stackSpec) error { params := &cloudformation.CreateStackInput{ StackName: aws.String(spec.name), - OnFailure: aws.String(cloudformation.OnFailureDelete), + OnFailure: cftypes.OnFailureDelete, Parameters: append( - []*cloudformation.Parameter{ + []cftypes.Parameter{ cfParam(parameterVPCIDParameter, spec.vpcID), cfParam(parameterInternetGatewayIDParameter, spec.internetGatewayID), }, routeTableParams(spec)..., ), TemplateBody: aws.String(spec.template), - TimeoutInMinutes: aws.Int64(int64(spec.timeoutInMinutes)), + TimeoutInMinutes: aws.Int32(int32(spec.timeoutInMinutes)), EnableTerminationProtection: aws.Bool(spec.stackTerminationProtection), Tags: spec.tags, } if !p.dry { - _, err := p.cloudformation.CreateStack(params) + _, err := p.cloudformation.CreateStack(ctx, params) if err != nil { - if awsErr, ok := err.(awserr.Error); ok { - if strings.Contains(awsErr.Message(), "does not exist") { - err = provider.NewDoesNotExistError(fmt.Sprintf("%s does not exist", spec.name)) - } else if awsErr.Code() == "AlreadyExistsException" { - err = provider.NewAlreadyExistsError(fmt.Sprintf("%s AlreadyExists", spec.name)) - } + var aer *cftypes.AlreadyExistsException + if errors.As(err, &aer) { + err = provider.NewAlreadyExistsError(fmt.Sprintf("%s AlreadyExists", spec.name)) } return err } @@ -595,14 +590,14 @@ func (p *AWSProvider) createCFStack(spec *stackSpec) error { defer cancel() return p.waitForStack(ctx, stackStatusCheckInterval, spec.name) } - p.logger.Debugf("%s: DRY: Stack to create: %s", p, params) - p.logger.Debugln(aws.StringValue(params.TemplateBody)) + p.logger.Debugf("%s: DRY: Stack to create: %v", p, params) + p.logger.Debugln(aws.ToString(params.TemplateBody)) return nil } -func routeTableParams(s *stackSpec) []*cloudformation.Parameter { - var params []*cloudformation.Parameter +func routeTableParams(s *stackSpec) []cftypes.Parameter { + var params []cftypes.Parameter for paramName, routeTableID := range s.tableID { params = append(params, cfParam(paramName, routeTableID)) } @@ -610,72 +605,74 @@ func routeTableParams(s *stackSpec) []*cloudformation.Parameter { return params } -func (p *AWSProvider) getStackByName(stackName string) (*cloudformation.Stack, error) { +func (p *AWSProvider) getStackByName(ctx context.Context, stackName string) (cftypes.Stack, error) { params := &cloudformation.DescribeStacksInput{ StackName: aws.String(stackName), } - resp, err := p.cloudformation.DescribeStacks(params) + resp, err := p.cloudformation.DescribeStacks(ctx, params) if err != nil { - return nil, err + return cftypes.Stack{}, err } // we expect only one stack if len(resp.Stacks) != 1 { - return nil, fmt.Errorf("unexpected response, got %d, expected 1 stack", len(resp.Stacks)) + return cftypes.Stack{}, fmt.Errorf("unexpected response, got %d, expected 1 stack", len(resp.Stacks)) } return resp.Stacks[0], nil } // getEgressStack gets the Egress stack by ClusterID tag or by static stack // name. -func (p *AWSProvider) getEgressStack() (*cloudformation.Stack, error) { +func (p *AWSProvider) getEgressStack(ctx context.Context) (cftypes.Stack, error) { tags := map[string]string{ p.clusterIDTagPrefix + p.clusterID: resourceLifecycleOwned, kubernetesApplicationTagKey: p.controllerID, } params := &cloudformation.DescribeStacksInput{} + paginator := cloudformation.NewDescribeStacksPaginator(p.cloudformation, params) + + var egressStack cftypes.Stack + for paginator.HasMorePages() { + resp, err := paginator.NextPage(ctx) + if err != nil { + return cftypes.Stack{}, err + } - var egressStack *cloudformation.Stack - err := p.cloudformation.DescribeStacksPages(params, func(resp *cloudformation.DescribeStacksOutput, lastPage bool) bool { for _, stack := range resp.Stacks { - if cloudformationHasTags(tags, stack.Tags) || aws.StringValue(stack.StackName) == staticLagacyStackName { + if cloudformationHasTags(tags, stack.Tags) || aws.ToString(stack.StackName) == staticLagacyStackName { egressStack = stack - return false + break } } - return true - }) - if err != nil { - return nil, err } return egressStack, nil } -func (p *AWSProvider) getStackTemplateBody(stack *cloudformation.Stack) (string, error) { +func (p *AWSProvider) getStackTemplateBody(ctx context.Context, stack cftypes.Stack) (string, error) { tParams := &cloudformation.GetTemplateInput{ StackName: stack.StackName, - TemplateStage: aws.String(cloudformation.TemplateStageOriginal), + TemplateStage: cftypes.TemplateStageOriginal, } - resp, err := p.cloudformation.GetTemplate(tParams) + resp, err := p.cloudformation.GetTemplate(ctx, tParams) if err != nil { return "", err } - return aws.StringValue(resp.TemplateBody), nil + return aws.ToString(resp.TemplateBody), nil } // cloudformationHasTags returns true if the expected tags are found in the // tags list. -func cloudformationHasTags(expected map[string]string, tags []*cloudformation.Tag) bool { +func cloudformationHasTags(expected map[string]string, tags []cftypes.Tag) bool { if len(expected) > len(tags) { return false } tagsMap := make(map[string]string, len(tags)) for _, tag := range tags { - tagsMap[aws.StringValue(tag.Key)] = aws.StringValue(tag.Value) + tagsMap[aws.ToString(tag.Key)] = aws.ToString(tag.Value) } for key, val := range expected { @@ -688,31 +685,31 @@ func cloudformationHasTags(expected map[string]string, tags []*cloudformation.Ta func (p *AWSProvider) waitForStack(ctx context.Context, waitTime time.Duration, stackName string) error { for { - stack, err := p.getStackByName(stackName) + stack, err := p.getStackByName(ctx, stackName) if err != nil { return err } - switch aws.StringValue(stack.StackStatus) { - case cloudformation.StackStatusUpdateComplete: + switch stack.StackStatus { + case cftypes.StackStatusUpdateComplete: return nil - case cloudformation.StackStatusCreateComplete: + case cftypes.StackStatusCreateComplete: return nil - case cloudformation.StackStatusDeleteComplete: + case cftypes.StackStatusDeleteComplete: return nil - case cloudformation.StackStatusCreateFailed: + case cftypes.StackStatusCreateFailed: return errCreateFailed - case cloudformation.StackStatusDeleteFailed: + case cftypes.StackStatusDeleteFailed: return errDeleteFailed - case cloudformation.StackStatusRollbackComplete: + case cftypes.StackStatusRollbackComplete: return errRollbackComplete - case cloudformation.StackStatusRollbackFailed: + case cftypes.StackStatusRollbackFailed: return errRollbackFailed - case cloudformation.StackStatusUpdateRollbackComplete: + case cftypes.StackStatusUpdateRollbackComplete: return errUpdateRollbackComplete - case cloudformation.StackStatusUpdateRollbackFailed: + case cftypes.StackStatusUpdateRollbackFailed: return errUpdateRollbackFailed } - p.logger.Debugf("Stack '%s' - [%s]", stackName, aws.StringValue(stack.StackStatus)) + p.logger.Debugf("Stack '%s' - [%s]", stackName, stack.StackStatus) select { case <-ctx.Done(): @@ -722,68 +719,52 @@ func (p *AWSProvider) waitForStack(ctx context.Context, waitTime time.Duration, } } -func defaultConfigProvider() client.ConfigProvider { - cfg := aws.NewConfig().WithMaxRetries(3) - cfg = cfg.WithHTTPClient(instrumented_http.NewClient(cfg.HTTPClient, nil)) - opts := session.Options{ - SharedConfigState: session.SharedConfigEnable, - Config: *cfg, - } - return session.Must(session.NewSessionWithOptions(opts)) -} - -func cfParam(key, value string) *cloudformation.Parameter { - return &cloudformation.Parameter{ +func cfParam(key, value string) cftypes.Parameter { + return cftypes.Parameter{ ParameterKey: aws.String(key), ParameterValue: aws.String(value), } } -func (p *AWSProvider) getInternetGatewayId(vpcID string) ([]*ec2.InternetGateway, error) { +func (p *AWSProvider) getInternetGatewayId(ctx context.Context, vpcID string) ([]ec2types.InternetGateway, error) { params := &ec2.DescribeInternetGatewaysInput{ - Filters: []*ec2.Filter{ + Filters: []ec2types.Filter{ { - Name: aws.String("attachment.vpc-id"), - Values: []*string{ - aws.String(vpcID), - }, + Name: aws.String("attachment.vpc-id"), + Values: []string{vpcID}, }, }, } - resp, err := p.ec2.DescribeInternetGateways(params) + resp, err := p.ec2.DescribeInternetGateways(ctx, params) if err != nil { return nil, err } return resp.InternetGateways, nil } -func (p *AWSProvider) getVpcID() ([]*ec2.Vpc, error) { +func (p *AWSProvider) getVpcID(ctx context.Context) ([]ec2types.Vpc, error) { params := &ec2.DescribeVpcsInput{} - resp, err := p.ec2.DescribeVpcs(params) + resp, err := p.ec2.DescribeVpcs(ctx, params) if err != nil { return nil, err } return resp.Vpcs, nil } -func (p *AWSProvider) getRouteTables(vpcID string) ([]*ec2.RouteTable, error) { +func (p *AWSProvider) getRouteTables(ctx context.Context, vpcID string) ([]ec2types.RouteTable, error) { params := &ec2.DescribeRouteTablesInput{ - Filters: []*ec2.Filter{ + Filters: []ec2types.Filter{ { - Name: aws.String("vpc-id"), - Values: []*string{ - aws.String(vpcID), - }, + Name: aws.String("vpc-id"), + Values: []string{vpcID}, }, { - Name: aws.String("tag:Type"), - Values: []*string{ - aws.String(tagDefaultTypeValueRouteTableID), - }, + Name: aws.String("tag:Type"), + Values: []string{tagDefaultTypeValueRouteTableID}, }, }, } - resp, err := p.ec2.DescribeRouteTables(params) + resp, err := p.ec2.DescribeRouteTables(ctx, params) if err != nil { return nil, err } @@ -800,10 +781,10 @@ func mergeTags(tags ...map[string]string) map[string]string { return mergedTags } -func tagMapToCloudformationTags(tags map[string]string) []*cloudformation.Tag { - cfTags := make([]*cloudformation.Tag, 0, len(tags)) +func tagMapToCloudformationTags(tags map[string]string) []cftypes.Tag { + cfTags := make([]cftypes.Tag, 0, len(tags)) for k, v := range tags { - tag := &cloudformation.Tag{ + tag := cftypes.Tag{ Key: aws.String(k), Value: aws.String(v), } diff --git a/provider/aws/aws_test.go b/provider/aws/aws_test.go index a5ff6ee..45da8ce 100644 --- a/provider/aws/aws_test.go +++ b/provider/aws/aws_test.go @@ -1,17 +1,18 @@ package aws import ( + "context" "errors" "fmt" "net" "sort" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/cloudformation" - "github.com/aws/aws-sdk-go/service/cloudformation/cloudformationiface" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ec2/ec2iface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/cloudformation" + cftypes "github.com/aws/aws-sdk-go-v2/service/cloudformation/types" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" "github.com/szuecs/kube-static-egress-controller/provider" @@ -22,21 +23,21 @@ const ( ) type mockedReceiveMsgs struct { - ec2iface.EC2API + ec2API RespVpcs ec2.DescribeVpcsOutput RespInternetGateways ec2.DescribeInternetGatewaysOutput RespRouteTables ec2.DescribeRouteTablesOutput } -func (m mockedReceiveMsgs) DescribeVpcs(in *ec2.DescribeVpcsInput) (*ec2.DescribeVpcsOutput, error) { +func (m mockedReceiveMsgs) DescribeVpcs(_ context.Context, in *ec2.DescribeVpcsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeVpcsOutput, error) { return &m.RespVpcs, nil } -func (m mockedReceiveMsgs) DescribeInternetGateways(in *ec2.DescribeInternetGatewaysInput) (*ec2.DescribeInternetGatewaysOutput, error) { +func (m mockedReceiveMsgs) DescribeInternetGateways(_ context.Context, in *ec2.DescribeInternetGatewaysInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInternetGatewaysOutput, error) { return &m.RespInternetGateways, nil } -func (m mockedReceiveMsgs) DescribeRouteTables(in *ec2.DescribeRouteTablesInput) (*ec2.DescribeRouteTablesOutput, error) { +func (m mockedReceiveMsgs) DescribeRouteTables(_ context.Context, in *ec2.DescribeRouteTablesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeRouteTablesOutput, error) { return &m.RespRouteTables, nil } @@ -51,7 +52,7 @@ func TestGenerateStackSpec(t *testing.T) { additionalStackTags := map[string]string{ "foo": "bar", } - expectedTags := []*cloudformation.Tag{ + expectedTags := []cftypes.Tag{ { Key: aws.String("foo"), Value: aws.String("bar"), @@ -74,9 +75,9 @@ func TestGenerateStackSpec(t *testing.T) { netA.String(): netA, }, } - p := NewAWSProvider("cluster-x", "controller-x", true, "", clusterIDTagPrefix, natCidrBlocks, availabilityZones, false, additionalStackTags) + fakeVpcsResp := ec2.DescribeVpcsOutput{ - Vpcs: []*ec2.Vpc{ + Vpcs: []ec2types.Vpc{ { VpcId: aws.String("vpc-1111"), IsDefault: aws.Bool(true), @@ -84,18 +85,18 @@ func TestGenerateStackSpec(t *testing.T) { }, } fakeIgwResp := ec2.DescribeInternetGatewaysOutput{ - InternetGateways: []*ec2.InternetGateway{ + InternetGateways: []ec2types.InternetGateway{ { InternetGatewayId: aws.String("igw-1111")}, }, } fakeRouteTablesResp := ec2.DescribeRouteTablesOutput{ - RouteTables: []*ec2.RouteTable{ + RouteTables: []ec2types.RouteTable{ { VpcId: aws.String("vpc-1111"), - Routes: []*ec2.Route{}, + Routes: []ec2types.Route{}, RouteTableId: aws.String("rtb-1111"), - Tags: []*ec2.Tag{ + Tags: []ec2types.Tag{ { Key: aws.String(tagDefaultAZKeyRouteTableID), Value: aws.String("eu-central-1a"), @@ -105,12 +106,25 @@ func TestGenerateStackSpec(t *testing.T) { }, } - p.ec2 = mockedReceiveMsgs{ - RespVpcs: fakeVpcsResp, - RespInternetGateways: fakeIgwResp, - RespRouteTables: fakeRouteTablesResp, + p := &AWSProvider{ + clusterID: "cluster-x", + clusterIDTagPrefix: clusterIDTagPrefix, + controllerID: "controller-x", + dry: true, + vpcID: "", + natCidrBlocks: natCidrBlocks, + availabilityZones: availabilityZones, + ec2: mockedReceiveMsgs{ + RespVpcs: fakeVpcsResp, + RespInternetGateways: fakeIgwResp, + RespRouteTables: fakeRouteTablesResp, + }, + stackTerminationProtection: false, + additionalStackTags: additionalStackTags, + logger: log.WithFields(log.Fields{"provider": ProviderName}), } - stackSpec, err := p.generateStackSpec(destinationCidrBlocks) + + stackSpec, err := p.generateStackSpec(context.Background(), destinationCidrBlocks) if err != nil { t.Error("Failed to generate CloudFormation stack") } @@ -129,7 +143,7 @@ func TestGenerateStackSpec(t *testing.T) { } // sort tags to ensure stable comparison sort.Slice(stackSpec.tags, func(i, j int) bool { - return aws.StringValue(stackSpec.tags[i].Key) < aws.StringValue(stackSpec.tags[j].Key) + return aws.ToString(stackSpec.tags[i].Key) < aws.ToString(stackSpec.tags[j].Key) }) require.EqualValues(t, expectedTags, stackSpec.tags) } @@ -146,7 +160,19 @@ func TestGenerateTemplate(t *testing.T) { netA.String(): netA, }, } - p := NewAWSProvider("cluster-x", "controller-x", true, "", clusterIDTagPrefix, natCidrBlocks, availabilityZones, false, nil) + p := &AWSProvider{ + clusterID: "cluster-x", + clusterIDTagPrefix: clusterIDTagPrefix, + controllerID: "controller-x", + dry: true, + vpcID: "", + natCidrBlocks: natCidrBlocks, + availabilityZones: availabilityZones, + stackTerminationProtection: false, + additionalStackTags: nil, + logger: log.WithFields(log.Fields{"provider": ProviderName}), + } + expect := `{"AWSTemplateFormatVersion":"2010-09-09","Description":"Static Egress Stack","Parameters":{"AZ1RouteTableIDParameter":{"Type":"String","Description":"Route Table ID No 1"},"InternetGatewayIDParameter":{"Type":"String","Description":"Internet Gateway ID"},"VPCIDParameter":{"Type":"AWS::EC2::VPC::Id","Description":"VPC ID"}},"Resources":{"EIP1":{"Type":"AWS::EC2::EIP","Properties":{"Domain":"vpc"}},"NATGateway1":{"Type":"AWS::EC2::NatGateway","Properties":{"AllocationId":{"Fn::GetAtt":["EIP1","AllocationId"]},"SubnetId":{"Ref":"NATSubnet1"}}},"NATSubnet1":{"Type":"AWS::EC2::Subnet","Properties":{"AvailabilityZone":"eu-central-1a","CidrBlock":"172.31.64.0/28","Tags":[{"Key":"Name","Value":"nat-eu-central-1a"}],"VpcId":{"Ref":"VPCIDParameter"}}},"NATSubnetRoute1":{"Type":"AWS::EC2::Route","Properties":{"DestinationCidrBlock":"0.0.0.0/0","GatewayId":{"Ref":"InternetGatewayIDParameter"},"RouteTableId":{"Ref":"NATSubnetRouteTable1"}}},"NATSubnetRouteTable1":{"Type":"AWS::EC2::RouteTable","Properties":{"Tags":[{"Key":"Name","Value":"nat-eu-central-1a"}],"VpcId":{"Ref":"VPCIDParameter"}}},"NATSubnetRouteTableAssociation1":{"Type":"AWS::EC2::SubnetRouteTableAssociation","Properties":{"RouteTableId":{"Ref":"NATSubnetRouteTable1"},"SubnetId":{"Ref":"NATSubnet1"}}},"RouteToNAT1z213x95x138x236y32":{"Type":"AWS::EC2::Route","Properties":{"DestinationCidrBlock":"213.95.138.236/32","NatGatewayId":{"Ref":"NATGateway1"},"RouteTableId":{"Ref":"AZ1RouteTableIDParameter"}}}},"Outputs":{"EIP1":{"Description":"external IP of the NATGateway1","Value":{"Ref":"EIP1"}}}}` template := p.generateTemplate( destinationCidrBlocks, @@ -160,16 +186,15 @@ func TestGenerateTemplate(t *testing.T) { } type mockCloudformation struct { - cloudformationiface.CloudFormationAPI err error - stack *cloudformation.Stack + stack cftypes.Stack templateBody string } -func (cf *mockCloudformation) DescribeStacks(input *cloudformation.DescribeStacksInput) (*cloudformation.DescribeStacksOutput, error) { - if cf.stack != nil { +func (cf *mockCloudformation) DescribeStacks(_ context.Context, input *cloudformation.DescribeStacksInput, optFns ...func(*cloudformation.Options)) (*cloudformation.DescribeStacksOutput, error) { + if cf.stack.StackName != nil { return &cloudformation.DescribeStacksOutput{ - Stacks: []*cloudformation.Stack{cf.stack}, + Stacks: []cftypes.Stack{cf.stack}, }, nil } return &cloudformation.DescribeStacksOutput{ @@ -177,7 +202,7 @@ func (cf *mockCloudformation) DescribeStacks(input *cloudformation.DescribeStack }, cf.err } -func (cf *mockCloudformation) GetTemplate(input *cloudformation.GetTemplateInput) (*cloudformation.GetTemplateOutput, error) { +func (cf *mockCloudformation) GetTemplate(_ context.Context, input *cloudformation.GetTemplateInput, optFns ...func(*cloudformation.Options)) (*cloudformation.GetTemplateOutput, error) { if cf.templateBody != "" { return &cloudformation.GetTemplateOutput{ TemplateBody: aws.String(cf.templateBody), @@ -188,19 +213,10 @@ func (cf *mockCloudformation) GetTemplate(input *cloudformation.GetTemplateInput }, cf.err } -func (cf *mockCloudformation) DescribeStacksPages(input *cloudformation.DescribeStacksInput, fn func(*cloudformation.DescribeStacksOutput, bool) bool) error { - if cf.stack != nil { - fn(&cloudformation.DescribeStacksOutput{ - Stacks: []*cloudformation.Stack{cf.stack}, - }, true) - return nil - } - return cf.err -} - -func (cf *mockCloudformation) CreateStack(input *cloudformation.CreateStackInput) (*cloudformation.CreateStackOutput, error) { - cf.stack = &cloudformation.Stack{ - StackStatus: aws.String(cloudformation.StackStatusCreateComplete), +func (cf *mockCloudformation) CreateStack(_ context.Context, input *cloudformation.CreateStackInput, optFns ...func(*cloudformation.Options)) (*cloudformation.CreateStackOutput, error) { + cf.stack = cftypes.Stack{ + StackName: aws.String("stack"), + StackStatus: cftypes.StackStatusCreateComplete, Tags: input.Tags, } return &cloudformation.CreateStackOutput{ @@ -208,40 +224,42 @@ func (cf *mockCloudformation) CreateStack(input *cloudformation.CreateStackInput }, cf.err } -func (cf *mockCloudformation) UpdateStack(input *cloudformation.UpdateStackInput) (*cloudformation.UpdateStackOutput, error) { - cf.stack = &cloudformation.Stack{ - StackStatus: aws.String(cloudformation.StackStatusUpdateComplete), +func (cf *mockCloudformation) UpdateStack(_ context.Context, input *cloudformation.UpdateStackInput, optFns ...func(*cloudformation.Options)) (*cloudformation.UpdateStackOutput, error) { + cf.stack = cftypes.Stack{ + StackName: aws.String("stack"), + StackStatus: cftypes.StackStatusUpdateComplete, Tags: input.Tags, } - cf.templateBody = aws.StringValue(input.TemplateBody) + cf.templateBody = aws.ToString(input.TemplateBody) return &cloudformation.UpdateStackOutput{ StackId: aws.String(""), }, cf.err } -func (cf *mockCloudformation) UpdateTerminationProtection(*cloudformation.UpdateTerminationProtectionInput) (*cloudformation.UpdateTerminationProtectionOutput, error) { +func (cf *mockCloudformation) UpdateTerminationProtection(context.Context, *cloudformation.UpdateTerminationProtectionInput, ...func(*cloudformation.Options)) (*cloudformation.UpdateTerminationProtectionOutput, error) { return nil, cf.err } -func (cf *mockCloudformation) DeleteStack(*cloudformation.DeleteStackInput) (*cloudformation.DeleteStackOutput, error) { - cf.stack = &cloudformation.Stack{ - StackStatus: aws.String(cloudformation.StackStatusDeleteComplete), +func (cf *mockCloudformation) DeleteStack(context.Context, *cloudformation.DeleteStackInput, ...func(*cloudformation.Options)) (*cloudformation.DeleteStackOutput, error) { + cf.stack = cftypes.Stack{ + StackName: aws.String("stack"), + StackStatus: cftypes.StackStatusDeleteComplete, } return nil, cf.err } type mockEC2 struct { - ec2iface.EC2API + ec2API err error describeInternetGatewaysOutput *ec2.DescribeInternetGatewaysOutput describeRouteTables *ec2.DescribeRouteTablesOutput } -func (ec2 *mockEC2) DescribeInternetGateways(*ec2.DescribeInternetGatewaysInput) (*ec2.DescribeInternetGatewaysOutput, error) { +func (ec2 *mockEC2) DescribeInternetGateways(context.Context, *ec2.DescribeInternetGatewaysInput, ...func(*ec2.Options)) (*ec2.DescribeInternetGatewaysOutput, error) { return ec2.describeInternetGatewaysOutput, ec2.err } -func (ec2 *mockEC2) DescribeRouteTables(*ec2.DescribeRouteTablesInput) (*ec2.DescribeRouteTablesOutput, error) { +func (ec2 *mockEC2) DescribeRouteTables(context.Context, *ec2.DescribeRouteTablesInput, ...func(*ec2.Options)) (*ec2.DescribeRouteTablesOutput, error) { return ec2.describeRouteTables, ec2.err } @@ -255,7 +273,7 @@ func TestEnsure(tt *testing.T) { ec2 *mockEC2 configs map[provider.Resource]map[string]*net.IPNet success bool - expectedStack *cloudformation.Stack + expectedStack cftypes.Stack expectedCIDRsFromTemplate map[string]struct{} }{ { @@ -263,28 +281,26 @@ func TestEnsure(tt *testing.T) { cf: &mockCloudformation{ err: errors.New("failed"), }, - success: false, - expectedStack: nil, + success: false, }, { - msg: "don't do anything if the stack doesn't exist and the config is empty", - cf: &mockCloudformation{}, - success: true, - expectedStack: nil, + msg: "don't do anything if the stack doesn't exist and the config is empty", + cf: &mockCloudformation{}, + success: true, }, { msg: "create new stack if it doesn't already exists", cf: &mockCloudformation{}, ec2: &mockEC2{ describeInternetGatewaysOutput: &ec2.DescribeInternetGatewaysOutput{ - InternetGateways: []*ec2.InternetGateway{ + InternetGateways: []ec2types.InternetGateway{ { InternetGatewayId: aws.String(""), }, }, }, describeRouteTables: &ec2.DescribeRouteTablesOutput{ - RouteTables: []*ec2.RouteTable{ + RouteTables: []ec2types.RouteTable{ { RouteTableId: aws.String(""), }, @@ -300,9 +316,10 @@ func TestEnsure(tt *testing.T) { }, }, success: true, - expectedStack: &cloudformation.Stack{ - StackStatus: aws.String(cloudformation.StackStatusCreateComplete), - Tags: []*cloudformation.Tag{ + expectedStack: cftypes.Stack{ + StackName: aws.String("stack"), + StackStatus: cftypes.StackStatusCreateComplete, + Tags: []cftypes.Tag{ { Key: aws.String(clusterIDTagPrefix + "cluster-x"), Value: aws.String(resourceLifecycleOwned), @@ -317,9 +334,10 @@ func TestEnsure(tt *testing.T) { { msg: "delete stack if there are no configs", cf: &mockCloudformation{ - stack: &cloudformation.Stack{ - StackStatus: aws.String(cloudformation.StackStatusCreateComplete), - Tags: []*cloudformation.Tag{ + stack: cftypes.Stack{ + StackName: aws.String("stack"), + StackStatus: cftypes.StackStatusCreateComplete, + Tags: []cftypes.Tag{ { Key: aws.String(clusterIDTagPrefix + "cluster-x"), Value: aws.String(resourceLifecycleOwned), @@ -333,14 +351,14 @@ func TestEnsure(tt *testing.T) { }, ec2: &mockEC2{ describeInternetGatewaysOutput: &ec2.DescribeInternetGatewaysOutput{ - InternetGateways: []*ec2.InternetGateway{ + InternetGateways: []ec2types.InternetGateway{ { InternetGatewayId: aws.String(""), }, }, }, describeRouteTables: &ec2.DescribeRouteTablesOutput{ - RouteTables: []*ec2.RouteTable{ + RouteTables: []ec2types.RouteTable{ { RouteTableId: aws.String(""), }, @@ -349,16 +367,18 @@ func TestEnsure(tt *testing.T) { }, configs: nil, success: true, - expectedStack: &cloudformation.Stack{ - StackStatus: aws.String(cloudformation.StackStatusDeleteComplete), + expectedStack: cftypes.Stack{ + StackName: aws.String("stack"), + StackStatus: cftypes.StackStatusDeleteComplete, }, }, { msg: "update stack if there are changes to the configs", cf: &mockCloudformation{ - stack: &cloudformation.Stack{ - StackStatus: aws.String(cloudformation.StackStatusCreateComplete), - Tags: []*cloudformation.Tag{ + stack: cftypes.Stack{ + StackName: aws.String("stack"), + StackStatus: cftypes.StackStatusCreateComplete, + Tags: []cftypes.Tag{ { Key: aws.String(clusterIDTagPrefix + "cluster-x"), Value: aws.String(resourceLifecycleOwned), @@ -373,17 +393,17 @@ func TestEnsure(tt *testing.T) { }, ec2: &mockEC2{ describeInternetGatewaysOutput: &ec2.DescribeInternetGatewaysOutput{ - InternetGateways: []*ec2.InternetGateway{ + InternetGateways: []ec2types.InternetGateway{ { InternetGatewayId: aws.String(""), }, }, }, describeRouteTables: &ec2.DescribeRouteTablesOutput{ - RouteTables: []*ec2.RouteTable{ + RouteTables: []ec2types.RouteTable{ { RouteTableId: aws.String("foo"), - Tags: []*ec2.Tag{ + Tags: []ec2types.Tag{ { Key: aws.String("AvailabilityZone"), Value: aws.String("eu-central-1a"), @@ -403,9 +423,10 @@ func TestEnsure(tt *testing.T) { }, }, success: true, - expectedStack: &cloudformation.Stack{ - StackStatus: aws.String(cloudformation.StackStatusUpdateComplete), - Tags: []*cloudformation.Tag{ + expectedStack: cftypes.Stack{ + StackName: aws.String("stack"), + StackStatus: cftypes.StackStatusUpdateComplete, + Tags: []cftypes.Tag{ { Key: aws.String(clusterIDTagPrefix + "cluster-x"), Value: aws.String(resourceLifecycleOwned), @@ -424,9 +445,10 @@ func TestEnsure(tt *testing.T) { { msg: "correctly update 'old' stack if there are changes to the configs", cf: &mockCloudformation{ - stack: &cloudformation.Stack{ - StackStatus: aws.String(cloudformation.StackStatusCreateComplete), - Tags: []*cloudformation.Tag{ + stack: cftypes.Stack{ + StackName: aws.String("stack"), + StackStatus: cftypes.StackStatusCreateComplete, + Tags: []cftypes.Tag{ { Key: aws.String(clusterIDTagPrefix + "cluster-x"), Value: aws.String(resourceLifecycleOwned), @@ -441,17 +463,17 @@ func TestEnsure(tt *testing.T) { }, ec2: &mockEC2{ describeInternetGatewaysOutput: &ec2.DescribeInternetGatewaysOutput{ - InternetGateways: []*ec2.InternetGateway{ + InternetGateways: []ec2types.InternetGateway{ { InternetGatewayId: aws.String(""), }, }, }, describeRouteTables: &ec2.DescribeRouteTablesOutput{ - RouteTables: []*ec2.RouteTable{ + RouteTables: []ec2types.RouteTable{ { RouteTableId: aws.String("foo"), - Tags: []*ec2.Tag{{ + Tags: []ec2types.Tag{{ Key: aws.String("AvailabilityZone"), Value: aws.String("eu-central-1a"), }}, @@ -469,9 +491,10 @@ func TestEnsure(tt *testing.T) { }, }, success: true, - expectedStack: &cloudformation.Stack{ - StackStatus: aws.String(cloudformation.StackStatusUpdateComplete), - Tags: []*cloudformation.Tag{ + expectedStack: cftypes.Stack{ + StackName: aws.String("stack"), + StackStatus: cftypes.StackStatusUpdateComplete, + Tags: []cftypes.Tag{ { Key: aws.String(clusterIDTagPrefix + "cluster-x"), Value: aws.String(resourceLifecycleOwned), @@ -510,13 +533,13 @@ func TestEnsure(tt *testing.T) { logger: log.WithFields(log.Fields{"provider": ProviderName}), } - err := provider.Ensure(tc.configs) + err := provider.Ensure(context.Background(), tc.configs) if tc.success { require.NoError(t, err) - if tc.cf.stack != nil && len(tc.cf.stack.Tags) > 0 { + if tc.cf.stack.StackName != nil && len(tc.cf.stack.Tags) > 0 { // sort tags to ensure stable comparison sort.Slice(tc.cf.stack.Tags, func(i, j int) bool { - return aws.StringValue(tc.cf.stack.Tags[i].Key) < aws.StringValue(tc.cf.stack.Tags[j].Key) + return aws.ToString(tc.cf.stack.Tags[i].Key) < aws.ToString(tc.cf.stack.Tags[j].Key) }) } require.Equal(t, tc.expectedStack, tc.cf.stack) @@ -536,7 +559,7 @@ func TestCloudformationHasTags(tt *testing.T) { for _, tc := range []struct { msg string expectedTags map[string]string - tags []*cloudformation.Tag + tags []cftypes.Tag expected bool }{ { @@ -544,7 +567,7 @@ func TestCloudformationHasTags(tt *testing.T) { expectedTags: map[string]string{ "foo": "bar", }, - tags: []*cloudformation.Tag{ + tags: []cftypes.Tag{ { Key: aws.String("foo"), Value: aws.String("bar"), @@ -558,7 +581,7 @@ func TestCloudformationHasTags(tt *testing.T) { "foo": "bar", "foz": "baz", }, - tags: []*cloudformation.Tag{ + tags: []cftypes.Tag{ { Key: aws.String("foo"), Value: aws.String("bar"), @@ -571,7 +594,7 @@ func TestCloudformationHasTags(tt *testing.T) { expectedTags: map[string]string{ "foo": "baz", }, - tags: []*cloudformation.Tag{ + tags: []cftypes.Tag{ { Key: aws.String("foo"), Value: aws.String("bar"), diff --git a/provider/aws/iface.go b/provider/aws/iface.go new file mode 100644 index 0000000..8eee280 --- /dev/null +++ b/provider/aws/iface.go @@ -0,0 +1,23 @@ +package aws + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/service/cloudformation" + "github.com/aws/aws-sdk-go-v2/service/ec2" +) + +type cloudformationAPI interface { + UpdateTerminationProtection(ctx context.Context, params *cloudformation.UpdateTerminationProtectionInput, optFns ...func(*cloudformation.Options)) (*cloudformation.UpdateTerminationProtectionOutput, error) + DescribeStacks(ctx context.Context, params *cloudformation.DescribeStacksInput, optFns ...func(*cloudformation.Options)) (*cloudformation.DescribeStacksOutput, error) + CreateStack(ctx context.Context, params *cloudformation.CreateStackInput, optFns ...func(*cloudformation.Options)) (*cloudformation.CreateStackOutput, error) + UpdateStack(ctx context.Context, params *cloudformation.UpdateStackInput, optFns ...func(*cloudformation.Options)) (*cloudformation.UpdateStackOutput, error) + DeleteStack(ctx context.Context, params *cloudformation.DeleteStackInput, optFns ...func(*cloudformation.Options)) (*cloudformation.DeleteStackOutput, error) + GetTemplate(ctx context.Context, params *cloudformation.GetTemplateInput, optFns ...func(*cloudformation.Options)) (*cloudformation.GetTemplateOutput, error) +} + +type ec2API interface { + DescribeInternetGateways(ctx context.Context, params *ec2.DescribeInternetGatewaysInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInternetGatewaysOutput, error) + DescribeVpcs(ctx context.Context, params *ec2.DescribeVpcsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeVpcsOutput, error) + DescribeRouteTables(context.Context, *ec2.DescribeRouteTablesInput, ...func(*ec2.Options)) (*ec2.DescribeRouteTablesOutput, error) +} diff --git a/provider/noop/noop.go b/provider/noop/noop.go index 143d136..5e2497f 100644 --- a/provider/noop/noop.go +++ b/provider/noop/noop.go @@ -1,6 +1,7 @@ package noop import ( + "context" "net" log "github.com/sirupsen/logrus" @@ -19,7 +20,7 @@ func (p NoopProvider) String() string { return ProviderName } -func (p *NoopProvider) Ensure(configs map[provider.Resource]map[string]*net.IPNet) error { +func (p *NoopProvider) Ensure(_ context.Context, configs map[provider.Resource]map[string]*net.IPNet) error { log.Infof("%s Ensure(%v)", ProviderName, configs) return nil } diff --git a/provider/provider.go b/provider/provider.go index 41baf8b..9fb4491 100644 --- a/provider/provider.go +++ b/provider/provider.go @@ -1,6 +1,9 @@ package provider -import "net" +import ( + "context" + "net" +) type Resource struct { Name string @@ -13,6 +16,6 @@ type EgressConfig struct { } type Provider interface { - Ensure(configs map[Resource]map[string]*net.IPNet) error + Ensure(ctx context.Context, configs map[Resource]map[string]*net.IPNet) error String() string }