diff --git a/modules/aws/vpc.go b/modules/aws/vpc.go index d4a56638f..096b6cab8 100644 --- a/modules/aws/vpc.go +++ b/modules/aws/vpc.go @@ -22,16 +22,18 @@ type Vpc struct { // Subnet is a subnet in an availability zone. type Subnet struct { - Id string // The ID of the Subnet - AvailabilityZone string // The Availability Zone the subnet is in + Id string // The ID of the Subnet + AvailabilityZone string // The Availability Zone the subnet is in + Tags map[string]string // The tags associated with the subnet } -var vpcIDFilterName = "vpc-id" -var vpcResourceIdFilterName = "resource-id" -var vpcResourceTypeFilterName = "resource-type" -var vpcResourceTypeFilterValue = "vpc" -var isDefaultFilterName = "isDefault" -var isDefaultFilterValue = "true" +const vpcIDFilterName = "vpc-id" +const resourceTypeFilterName = "resource-id" +const resourceIdFilterName = "resource-type" +const vpcResourceTypeFilterValue = "vpc" +const subnetResourceTypeFilterValue = "subnet" +const isDefaultFilterName = "isDefault" +const isDefaultFilterValue = "true" // GetDefaultVpc fetches information about the default VPC in the given region. func GetDefaultVpc(t testing.TestingT, region string) *Vpc { @@ -42,7 +44,7 @@ func GetDefaultVpc(t testing.TestingT, region string) *Vpc { // GetDefaultVpcE fetches information about the default VPC in the given region. func GetDefaultVpcE(t testing.TestingT, region string) (*Vpc, error) { - defaultVpcFilter := ec2.Filter{Name: &isDefaultFilterName, Values: []*string{&isDefaultFilterValue}} + defaultVpcFilter := ec2.Filter{Name: aws.String(isDefaultFilterName), Values: []*string{aws.String(isDefaultFilterValue)}} vpcs, err := GetVpcsE(t, []*ec2.Filter{&defaultVpcFilter}, region) numVpcs := len(vpcs) @@ -62,7 +64,7 @@ func GetVpcById(t testing.TestingT, vpcId string, region string) *Vpc { // GetVpcByIdE fetches information about a VPC with given Id in the given region. func GetVpcByIdE(t testing.TestingT, vpcId string, region string) (*Vpc, error) { - vpcIdFilter := ec2.Filter{Name: &vpcIDFilterName, Values: []*string{&vpcId}} + vpcIdFilter := ec2.Filter{Name: aws.String(vpcIDFilterName), Values: []*string{&vpcId}} vpcs, err := GetVpcsE(t, []*ec2.Filter{&vpcIdFilter}, region) numVpcs := len(vpcs) @@ -137,7 +139,7 @@ func GetSubnetsForVpcE(t testing.TestingT, vpcID string, region string) ([]Subne return nil, err } - vpcIDFilter := ec2.Filter{Name: &vpcIDFilterName, Values: []*string{&vpcID}} + vpcIDFilter := ec2.Filter{Name: aws.String(vpcIDFilterName), Values: []*string{&vpcID}} subnetOutput, err := client.DescribeSubnets(&ec2.DescribeSubnetsInput{Filters: []*ec2.Filter{&vpcIDFilter}}) if err != nil { return nil, err @@ -146,7 +148,8 @@ func GetSubnetsForVpcE(t testing.TestingT, vpcID string, region string) ([]Subne subnets := []Subnet{} for _, ec2Subnet := range subnetOutput.Subnets { - subnet := Subnet{Id: aws.StringValue(ec2Subnet.SubnetId), AvailabilityZone: aws.StringValue(ec2Subnet.AvailabilityZone)} + subnetTags := GetTagsForSubnet(t, *ec2Subnet.SubnetId, region) + subnet := Subnet{Id: aws.StringValue(ec2Subnet.SubnetId), AvailabilityZone: aws.StringValue(ec2Subnet.AvailabilityZone), Tags: subnetTags} subnets = append(subnets, subnet) } @@ -166,8 +169,8 @@ func GetTagsForVpcE(t testing.TestingT, vpcID string, region string) (map[string client, err := NewEc2ClientE(t, region) require.NoError(t, err) - vpcResourceTypeFilter := ec2.Filter{Name: &vpcResourceTypeFilterName, Values: []*string{&vpcResourceTypeFilterValue}} - vpcResourceIdFilter := ec2.Filter{Name: &vpcResourceIdFilterName, Values: []*string{&vpcID}} + vpcResourceTypeFilter := ec2.Filter{Name: aws.String(resourceIdFilterName), Values: []*string{aws.String(vpcResourceTypeFilterValue)}} + vpcResourceIdFilter := ec2.Filter{Name: aws.String(resourceTypeFilterName), Values: []*string{&vpcID}} tagsOutput, err := client.DescribeTags(&ec2.DescribeTagsInput{Filters: []*ec2.Filter{&vpcResourceTypeFilter, &vpcResourceIdFilter}}) require.NoError(t, err) @@ -179,6 +182,32 @@ func GetTagsForVpcE(t testing.TestingT, vpcID string, region string) (map[string return tags, nil } +// GetTagsForSubnet gets the tags for the specified subnet. +func GetTagsForSubnet(t testing.TestingT, subnetId string, region string) map[string]string { + tags, err := GetTagsForSubnetE(t, subnetId, region) + require.NoError(t, err) + + return tags +} + +// GetTagsForSubnetE gets the tags for the specified subnet. +func GetTagsForSubnetE(t testing.TestingT, subnetId string, region string) (map[string]string, error) { + client, err := NewEc2ClientE(t, region) + require.NoError(t, err) + + subnetResourceTypeFilter := ec2.Filter{Name: aws.String(resourceIdFilterName), Values: []*string{aws.String(subnetResourceTypeFilterValue)}} + subnetResourceIdFilter := ec2.Filter{Name: aws.String(resourceTypeFilterName), Values: []*string{&subnetId}} + tagsOutput, err := client.DescribeTags(&ec2.DescribeTagsInput{Filters: []*ec2.Filter{&subnetResourceTypeFilter, &subnetResourceIdFilter}}) + require.NoError(t, err) + + tags := map[string]string{} + for _, tag := range tagsOutput.Tags { + tags[aws.StringValue(tag.Key)] = aws.StringValue(tag.Value) + } + + return tags, nil +} + // IsPublicSubnet returns True if the subnet identified by the given id in the provided region is public. func IsPublicSubnet(t testing.TestingT, subnetId string, region string) bool { isPublic, err := IsPublicSubnetE(t, subnetId, region) diff --git a/modules/aws/vpc_test.go b/modules/aws/vpc_test.go index efa9fa739..c4651cb3b 100644 --- a/modules/aws/vpc_test.go +++ b/modules/aws/vpc_test.go @@ -99,6 +99,35 @@ func TestGetTagsForVpc(t *testing.T) { assert.True(t, len(tags) == len(testTags)) } +func TestGetTagsForSubnet(t *testing.T) { + t.Parallel() + + region := GetRandomStableRegion(t, nil, nil) + vpc := createVpc(t, region) + defer deleteVpc(t, *vpc.VpcId, region) + + routeTable := createRouteTable(t, *vpc.VpcId, region) + subnet := createSubnet(t, *vpc.VpcId, *routeTable.RouteTableId, region) + + noTags := GetTagsForSubnet(t, *subnet.SubnetId, region) + assert.True(t, len(subnet.Tags) == 0) + assert.True(t, len(noTags) == 0) + + testTags := make(map[string]string) + testTags["TagKey1"] = "TagValue1" + testTags["TagKey2"] = "TagValue2" + + AddTagsToResource(t, region, *subnet.SubnetId, testTags) + + subnetWithTags := GetSubnetsForVpc(t, *vpc.VpcId, region)[0] + tags := GetTagsForSubnet(t, *subnet.SubnetId, region) + + assert.True(t, len(subnetWithTags.Tags) == len(testTags)) + assert.True(t, len(tags) == len(testTags)) + assert.True(t, testTags["TagKey1"] == "TagValue1") + assert.True(t, testTags["TagKey2"] == "TagValue2") +} + func createPublicRoute(t *testing.T, vpcId string, routeTableId string, region string) { ec2Client := NewEc2Client(t, region)