Skip to content

Commit

Permalink
Merge pull request #1248 from gruntwork-io/default_az_subnets
Browse files Browse the repository at this point in the history
Create a function to extract default az subnets
  • Loading branch information
james03160927 authored Feb 16, 2023
2 parents 9509a64 + 7074ff3 commit 81b5820
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 5 deletions.
31 changes: 26 additions & 5 deletions modules/aws/vpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type Subnet struct {
}

const vpcIDFilterName = "vpc-id"
const defaultForAzFilterName = "default-for-az"
const resourceTypeFilterName = "resource-id"
const resourceIdFilterName = "resource-type"
const vpcResourceTypeFilterValue = "vpc"
Expand Down Expand Up @@ -93,7 +94,8 @@ func GetVpcsE(t testing.TestingT, filters []*ec2.Filter, region string) ([]*Vpc,
retVal := make([]*Vpc, numVpcs)

for i, vpc := range vpcs.Vpcs {
subnets, err := GetSubnetsForVpcE(t, aws.StringValue(vpc.VpcId), region)
vpcIdFilter := generateVpcIdFilter(aws.StringValue(vpc.VpcId))
subnets, err := GetSubnetsForVpcE(t, region, []*ec2.Filter{&vpcIdFilter})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -127,22 +129,41 @@ func FindVpcName(vpc *ec2.Vpc) string {

// GetSubnetsForVpc gets the subnets in the specified VPC.
func GetSubnetsForVpc(t testing.TestingT, vpcID string, region string) []Subnet {
subnets, err := GetSubnetsForVpcE(t, vpcID, region)
vpcIDFilter := generateVpcIdFilter(vpcID)
subnets, err := GetSubnetsForVpcE(t, region, []*ec2.Filter{&vpcIDFilter})
if err != nil {
t.Fatal(err)
}
return subnets
}

// GetAzDefaultSubnetsForVpc gets the default az subnets in the specified VPC.
func GetAzDefaultSubnetsForVpc(t testing.TestingT, vpcID string, region string) []Subnet {
vpcIDFilter := generateVpcIdFilter(vpcID)
defaultForAzFilter := ec2.Filter{
Name: aws.String(defaultForAzFilterName),
Values: []*string{aws.String("true")},
}
subnets, err := GetSubnetsForVpcE(t, region, []*ec2.Filter{&vpcIDFilter, &defaultForAzFilter})
if err != nil {
t.Fatal(err)
}
return subnets
}

// generateVpcIdFilter is a helper method to generate vpc id filter
func generateVpcIdFilter(vpcID string) ec2.Filter {
return ec2.Filter{Name: aws.String(vpcIDFilterName), Values: []*string{&vpcID}}
}

// GetSubnetsForVpcE gets the subnets in the specified VPC.
func GetSubnetsForVpcE(t testing.TestingT, vpcID string, region string) ([]Subnet, error) {
func GetSubnetsForVpcE(t testing.TestingT, region string, filters []*ec2.Filter) ([]Subnet, error) {
client, err := NewEc2ClientE(t, region)
if err != nil {
return nil, err
}

vpcIDFilter := ec2.Filter{Name: aws.String(vpcIDFilterName), Values: []*string{&vpcID}}
subnetOutput, err := client.DescribeSubnets(&ec2.DescribeSubnetsInput{Filters: []*ec2.Filter{&vpcIDFilter}})
subnetOutput, err := client.DescribeSubnets(&ec2.DescribeSubnetsInput{Filters: filters})
if err != nil {
return nil, err
}
Expand Down
12 changes: 12 additions & 0 deletions modules/aws/vpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,18 @@ func TestGetTagsForSubnet(t *testing.T) {
assert.True(t, testTags["TagKey2"] == "TagValue2")
}

func TestGetDefaultAzSubnets(t *testing.T) {
t.Parallel()

region := GetRandomStableRegion(t, nil, nil)
vpc := GetDefaultVpc(t, region)

// Note: cannot know exact list of default azs aheard of time, but we know that
//it must be greater than 0 for default vpc.
subnets := GetAzDefaultSubnetsForVpc(t, vpc.Id, region)
assert.NotZero(t, len(subnets))
}

func createPublicRoute(t *testing.T, vpcId string, routeTableId string, region string) {
ec2Client := NewEc2Client(t, region)

Expand Down

0 comments on commit 81b5820

Please sign in to comment.