Skip to content

Commit

Permalink
Added GetTagsForSubnet() (and Tags property to Subnet struct) i…
Browse files Browse the repository at this point in the history
…nto `aws` module (#1005)
  • Loading branch information
SphenicPaul authored Oct 11, 2021
1 parent c583606 commit 3603b38
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 14 deletions.
57 changes: 43 additions & 14 deletions modules/aws/vpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,18 @@ type Vpc struct {

// Subnet is a subnet in an availability zone.
type Subnet struct {
Id string // The ID of the Subnet
AvailabilityZone string // The Availability Zone the subnet is in
Id string // The ID of the Subnet
AvailabilityZone string // The Availability Zone the subnet is in
Tags map[string]string // The tags associated with the subnet
}

var vpcIDFilterName = "vpc-id"
var vpcResourceIdFilterName = "resource-id"
var vpcResourceTypeFilterName = "resource-type"
var vpcResourceTypeFilterValue = "vpc"
var isDefaultFilterName = "isDefault"
var isDefaultFilterValue = "true"
const vpcIDFilterName = "vpc-id"
const resourceTypeFilterName = "resource-id"
const resourceIdFilterName = "resource-type"
const vpcResourceTypeFilterValue = "vpc"
const subnetResourceTypeFilterValue = "subnet"
const isDefaultFilterName = "isDefault"
const isDefaultFilterValue = "true"

// GetDefaultVpc fetches information about the default VPC in the given region.
func GetDefaultVpc(t testing.TestingT, region string) *Vpc {
Expand All @@ -42,7 +44,7 @@ func GetDefaultVpc(t testing.TestingT, region string) *Vpc {

// GetDefaultVpcE fetches information about the default VPC in the given region.
func GetDefaultVpcE(t testing.TestingT, region string) (*Vpc, error) {
defaultVpcFilter := ec2.Filter{Name: &isDefaultFilterName, Values: []*string{&isDefaultFilterValue}}
defaultVpcFilter := ec2.Filter{Name: aws.String(isDefaultFilterName), Values: []*string{aws.String(isDefaultFilterValue)}}
vpcs, err := GetVpcsE(t, []*ec2.Filter{&defaultVpcFilter}, region)

numVpcs := len(vpcs)
Expand All @@ -62,7 +64,7 @@ func GetVpcById(t testing.TestingT, vpcId string, region string) *Vpc {

// GetVpcByIdE fetches information about a VPC with given Id in the given region.
func GetVpcByIdE(t testing.TestingT, vpcId string, region string) (*Vpc, error) {
vpcIdFilter := ec2.Filter{Name: &vpcIDFilterName, Values: []*string{&vpcId}}
vpcIdFilter := ec2.Filter{Name: aws.String(vpcIDFilterName), Values: []*string{&vpcId}}
vpcs, err := GetVpcsE(t, []*ec2.Filter{&vpcIdFilter}, region)

numVpcs := len(vpcs)
Expand Down Expand Up @@ -137,7 +139,7 @@ func GetSubnetsForVpcE(t testing.TestingT, vpcID string, region string) ([]Subne
return nil, err
}

vpcIDFilter := ec2.Filter{Name: &vpcIDFilterName, Values: []*string{&vpcID}}
vpcIDFilter := ec2.Filter{Name: aws.String(vpcIDFilterName), Values: []*string{&vpcID}}
subnetOutput, err := client.DescribeSubnets(&ec2.DescribeSubnetsInput{Filters: []*ec2.Filter{&vpcIDFilter}})
if err != nil {
return nil, err
Expand All @@ -146,7 +148,8 @@ func GetSubnetsForVpcE(t testing.TestingT, vpcID string, region string) ([]Subne
subnets := []Subnet{}

for _, ec2Subnet := range subnetOutput.Subnets {
subnet := Subnet{Id: aws.StringValue(ec2Subnet.SubnetId), AvailabilityZone: aws.StringValue(ec2Subnet.AvailabilityZone)}
subnetTags := GetTagsForSubnet(t, *ec2Subnet.SubnetId, region)
subnet := Subnet{Id: aws.StringValue(ec2Subnet.SubnetId), AvailabilityZone: aws.StringValue(ec2Subnet.AvailabilityZone), Tags: subnetTags}
subnets = append(subnets, subnet)
}

Expand All @@ -166,8 +169,8 @@ func GetTagsForVpcE(t testing.TestingT, vpcID string, region string) (map[string
client, err := NewEc2ClientE(t, region)
require.NoError(t, err)

vpcResourceTypeFilter := ec2.Filter{Name: &vpcResourceTypeFilterName, Values: []*string{&vpcResourceTypeFilterValue}}
vpcResourceIdFilter := ec2.Filter{Name: &vpcResourceIdFilterName, Values: []*string{&vpcID}}
vpcResourceTypeFilter := ec2.Filter{Name: aws.String(resourceIdFilterName), Values: []*string{aws.String(vpcResourceTypeFilterValue)}}
vpcResourceIdFilter := ec2.Filter{Name: aws.String(resourceTypeFilterName), Values: []*string{&vpcID}}
tagsOutput, err := client.DescribeTags(&ec2.DescribeTagsInput{Filters: []*ec2.Filter{&vpcResourceTypeFilter, &vpcResourceIdFilter}})
require.NoError(t, err)

Expand All @@ -179,6 +182,32 @@ func GetTagsForVpcE(t testing.TestingT, vpcID string, region string) (map[string
return tags, nil
}

// GetTagsForSubnet gets the tags for the specified subnet.
func GetTagsForSubnet(t testing.TestingT, subnetId string, region string) map[string]string {
tags, err := GetTagsForSubnetE(t, subnetId, region)
require.NoError(t, err)

return tags
}

// GetTagsForSubnetE gets the tags for the specified subnet.
func GetTagsForSubnetE(t testing.TestingT, subnetId string, region string) (map[string]string, error) {
client, err := NewEc2ClientE(t, region)
require.NoError(t, err)

subnetResourceTypeFilter := ec2.Filter{Name: aws.String(resourceIdFilterName), Values: []*string{aws.String(subnetResourceTypeFilterValue)}}
subnetResourceIdFilter := ec2.Filter{Name: aws.String(resourceTypeFilterName), Values: []*string{&subnetId}}
tagsOutput, err := client.DescribeTags(&ec2.DescribeTagsInput{Filters: []*ec2.Filter{&subnetResourceTypeFilter, &subnetResourceIdFilter}})
require.NoError(t, err)

tags := map[string]string{}
for _, tag := range tagsOutput.Tags {
tags[aws.StringValue(tag.Key)] = aws.StringValue(tag.Value)
}

return tags, nil
}

// IsPublicSubnet returns True if the subnet identified by the given id in the provided region is public.
func IsPublicSubnet(t testing.TestingT, subnetId string, region string) bool {
isPublic, err := IsPublicSubnetE(t, subnetId, region)
Expand Down
29 changes: 29 additions & 0 deletions modules/aws/vpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,35 @@ func TestGetTagsForVpc(t *testing.T) {
assert.True(t, len(tags) == len(testTags))
}

func TestGetTagsForSubnet(t *testing.T) {
t.Parallel()

region := GetRandomStableRegion(t, nil, nil)
vpc := createVpc(t, region)
defer deleteVpc(t, *vpc.VpcId, region)

routeTable := createRouteTable(t, *vpc.VpcId, region)
subnet := createSubnet(t, *vpc.VpcId, *routeTable.RouteTableId, region)

noTags := GetTagsForSubnet(t, *subnet.SubnetId, region)
assert.True(t, len(subnet.Tags) == 0)
assert.True(t, len(noTags) == 0)

testTags := make(map[string]string)
testTags["TagKey1"] = "TagValue1"
testTags["TagKey2"] = "TagValue2"

AddTagsToResource(t, region, *subnet.SubnetId, testTags)

subnetWithTags := GetSubnetsForVpc(t, *vpc.VpcId, region)[0]
tags := GetTagsForSubnet(t, *subnet.SubnetId, region)

assert.True(t, len(subnetWithTags.Tags) == len(testTags))
assert.True(t, len(tags) == len(testTags))
assert.True(t, testTags["TagKey1"] == "TagValue1")
assert.True(t, testTags["TagKey2"] == "TagValue2")
}

func createPublicRoute(t *testing.T, vpcId string, routeTableId string, region string) {
ec2Client := NewEc2Client(t, region)

Expand Down

0 comments on commit 3603b38

Please sign in to comment.