Skip to content

Commit

Permalink
Use AWSMachinePool Spec subnet filters to fetch subnet IDs
Browse files Browse the repository at this point in the history
Co-authored-by: calvix
Co-authored-by: Shivani Singhal <[email protected]>
  • Loading branch information
shivi28 committed Mar 8, 2022
1 parent 8da76c2 commit 10af88c
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 44 deletions.
19 changes: 19 additions & 0 deletions exp/api/v1beta1/awsmachinepool_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,23 @@ func (r *AWSMachinePool) validateRootVolume() field.ErrorList {
return allErrs
}

func (r *AWSMachinePool) validateSubnets() field.ErrorList {
var allErrs field.ErrorList

if r.Spec.Subnets == nil {
return allErrs
}

for _, subnet := range r.Spec.Subnets {
if subnet.ID != nil && subnet.Filters != nil {
allErrs = append(allErrs, field.Forbidden(field.NewPath("spec.subnets.filters"), "providing either subnet ID or filter is supported, should not provide both"))
break
}
}

return allErrs
}

// ValidateCreate will do any extra validation when creating a AWSMachinePool.
func (r *AWSMachinePool) ValidateCreate() error {
log.Info("AWSMachinePool validate create", "name", r.Name)
Expand All @@ -90,6 +107,7 @@ func (r *AWSMachinePool) ValidateCreate() error {
allErrs = append(allErrs, r.validateDefaultCoolDown()...)
allErrs = append(allErrs, r.validateRootVolume()...)
allErrs = append(allErrs, r.Spec.AdditionalTags.Validate()...)
allErrs = append(allErrs, r.validateSubnets()...)

if len(allErrs) == 0 {
return nil
Expand All @@ -108,6 +126,7 @@ func (r *AWSMachinePool) ValidateUpdate(old runtime.Object) error {

allErrs = append(allErrs, r.validateDefaultCoolDown()...)
allErrs = append(allErrs, r.Spec.AdditionalTags.Validate()...)
allErrs = append(allErrs, r.validateSubnets()...)

if len(allErrs) == 0 {
return nil
Expand Down
8 changes: 1 addition & 7 deletions pkg/cloud/scope/machinepool.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"fmt"
"strings"

"github.com/aws/aws-sdk-go/aws"
"github.com/go-logr/logr"
"github.com/pkg/errors"
corev1 "k8s.io/api/core/v1"
Expand Down Expand Up @@ -222,12 +221,7 @@ func (m *MachinePoolScope) IsEKSManaged() bool {
}

// SubnetIDs returns the machine pool subnet IDs.
func (m *MachinePoolScope) SubnetIDs() ([]string, error) {
subnetIDs := make([]string, len(m.AWSMachinePool.Spec.Subnets))
for i, v := range m.AWSMachinePool.Spec.Subnets {
subnetIDs[i] = aws.StringValue(v.ID)
}

func (m *MachinePoolScope) SubnetIDs(subnetIDs []string) ([]string, error) {
strategy, err := newDefaultSubnetPlacementStrategy(&m.Logger)
if err != nil {
return subnetIDs, fmt.Errorf("getting subnet placement strategy: %w", err)
Expand Down
50 changes: 40 additions & 10 deletions pkg/cloud/services/autoscaling/autoscalinggroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/autoscaling"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/pkg/errors"
"k8s.io/utils/pointer"

Expand Down Expand Up @@ -144,7 +145,7 @@ func (s *Service) GetASGByName(scope *scope.MachinePoolScope) (*expinfrav1.AutoS

// CreateASG runs an autoscaling group.
func (s *Service) CreateASG(scope *scope.MachinePoolScope) (*expinfrav1.AutoScalingGroup, error) {
subnets, err := scope.SubnetIDs()
subnets, err := s.SubnetIDs(scope)
if err != nil {
return nil, fmt.Errorf("getting subnets for ASG: %w", err)
}
Expand Down Expand Up @@ -267,15 +268,9 @@ func (s *Service) DeleteASG(name string) error {

// UpdateASG will update the ASG of a service.
func (s *Service) UpdateASG(scope *scope.MachinePoolScope) error {
subnetIDs := make([]string, len(scope.AWSMachinePool.Spec.Subnets))
for i, v := range scope.AWSMachinePool.Spec.Subnets {
subnetIDs[i] = aws.StringValue(v.ID)
}

if len(subnetIDs) == 0 {
for _, subnet := range scope.InfraCluster.Subnets() {
subnetIDs = append(subnetIDs, subnet.ID)
}
subnetIDs, err := s.SubnetIDs(scope)
if err != nil {
return fmt.Errorf("getting subnets for ASG: %w", err)
}

input := &autoscaling.UpdateAutoScalingGroupInput{
Expand Down Expand Up @@ -465,3 +460,38 @@ func mapToTags(input map[string]string, resourceID *string) []*autoscaling.Tag {
}
return tags
}

// SubnetIDs return subnet IDs of a AWSMachinePool based on given subnetIDs and filters.
func (s *Service) SubnetIDs(scope *scope.MachinePoolScope) ([]string, error) {
subnetIDs := make([]string, 0)
var inputFilters = make([]*ec2.Filter, 0)

for _, subnet := range scope.AWSMachinePool.Spec.Subnets {
switch {
case subnet.ID != nil:
subnetIDs = append(subnetIDs, aws.StringValue(subnet.ID))
case subnet.Filters != nil:
for _, eachFilter := range subnet.Filters {
inputFilters = append(inputFilters, &ec2.Filter{
Name: aws.String(eachFilter.Name),
Values: aws.StringSlice(eachFilter.Values),
})
}
}
}

if len(inputFilters) > 0 {
out, err := s.EC2Client.DescribeSubnets(&ec2.DescribeSubnetsInput{
Filters: inputFilters,
})
if err != nil {
return nil, err
}

for _, subnet := range out.Subnets {
subnetIDs = append(subnetIDs, *subnet.SubnetId)
}
}

return scope.SubnetIDs(subnetIDs)
}
Loading

0 comments on commit 10af88c

Please sign in to comment.