Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cloudfront Distribution: Origin Groups support #7202

Merged
merged 10 commits into from
Mar 17, 2019
Merged
146 changes: 145 additions & 1 deletion aws/cloudfront_distribution_configuration_structure.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ func expandDistributionConfig(d *schema.ResourceData) *cloudfront.DistributionCo
if v, ok := d.GetOk("viewer_certificate"); ok {
distributionConfig.ViewerCertificate = expandViewerCertificate(v.([]interface{})[0].(map[string]interface{}))
}

if v, ok := d.GetOk("origin_group"); ok {
distributionConfig.OriginGroups = expandOriginGroups(v.(*schema.Set))
}
return distributionConfig
}

Expand Down Expand Up @@ -151,6 +153,12 @@ func flattenDistributionConfig(d *schema.ResourceData, distributionConfig *cloud
return err
}
}
if *distributionConfig.OriginGroups.Quantity > 0 {
err = d.Set("origin_group", flattenOriginGroups(distributionConfig.OriginGroups))
if err != nil {
return err
}
}

return nil
}
Expand Down Expand Up @@ -595,6 +603,104 @@ func flattenOrigin(or *cloudfront.Origin) map[string]interface{} {
return m
}

func expandOriginGroups(s *schema.Set) *cloudfront.OriginGroups {
qty := 0
items := []*cloudfront.OriginGroup{}
for _, v := range s.List() {
items = append(items, expandOriginGroup(v.(map[string]interface{})))
qty++
}
return &cloudfront.OriginGroups{
Quantity: aws.Int64(int64(qty)),
Items: items,
}
}

func flattenOriginGroups(ogs *cloudfront.OriginGroups) *schema.Set {
s := []interface{}{}
for _, v := range ogs.Items {
s = append(s, flattenOriginGroup(v))
}
return schema.NewSet(originGroupHash, s)
}

func expandOriginGroup(m map[string]interface{}) *cloudfront.OriginGroup {
failoverCriteria := m["failover_criteria"].([]interface{})[0].(map[string]interface{})
members := m["member"].([]interface{})
originGroup := &cloudfront.OriginGroup{
Id: aws.String(m["origin_id"].(string)),
FailoverCriteria: expandOriginGroupFailoverCriteria(failoverCriteria),
Members: expandMembers(members),
}
return originGroup
}

func flattenOriginGroup(og *cloudfront.OriginGroup) map[string]interface{} {
m := make(map[string]interface{})
m["origin_id"] = *og.Id
if og.FailoverCriteria != nil {
m["failover_criteria"] = flattenOriginGroupFailoverCriteria(og.FailoverCriteria)
}
if og.Members != nil {
m["member"] = flattenOriginGroupMembers(og.Members)
}
return m
}

func expandOriginGroupFailoverCriteria(m map[string]interface{}) *cloudfront.OriginGroupFailoverCriteria {
failoverCriteria := &cloudfront.OriginGroupFailoverCriteria{}
if v, ok := m["status_codes"]; ok {
codes := []*int64{}
for _, code := range v.(*schema.Set).List() {
codes = append(codes, aws.Int64(int64(code.(int))))
}
failoverCriteria.StatusCodes = &cloudfront.StatusCodes{
Items: codes,
Quantity: aws.Int64(int64(len(codes))),
}
}
return failoverCriteria
}

func flattenOriginGroupFailoverCriteria(ogfc *cloudfront.OriginGroupFailoverCriteria) []interface{} {
m := make(map[string]interface{})
if ogfc.StatusCodes.Items != nil {
l := []interface{}{}
for _, i := range ogfc.StatusCodes.Items {
l = append(l, int(*i))
}
m["status_codes"] = schema.NewSet(schema.HashInt, l)
}
return []interface{}{m}
}

func expandMembers(l []interface{}) *cloudfront.OriginGroupMembers {
qty := 0
items := []*cloudfront.OriginGroupMember{}
for _, m := range l {
ogm := &cloudfront.OriginGroupMember{
OriginId: aws.String(m.(map[string]interface{})["origin_id"].(string)),
}
items = append(items, ogm)
qty++
}
return &cloudfront.OriginGroupMembers{
Quantity: aws.Int64(int64(qty)),
Items: items,
}
}

func flattenOriginGroupMembers(ogm *cloudfront.OriginGroupMembers) []interface{} {
s := []interface{}{}
for _, i := range ogm.Items {
m := map[string]interface{}{
"origin_id": *i.OriginId,
}
s = append(s, m)
}
return s
}

// Assemble the hash for the aws_cloudfront_distribution origin
// TypeSet attribute.
func originHash(v interface{}) int {
Expand All @@ -621,6 +727,44 @@ func originHash(v interface{}) int {
return hashcode.String(buf.String())
}

// Assemble the hash for the aws_cloudfront_distribution origin group
// TypeSet attribute.
func originGroupHash(v interface{}) int {
var buf bytes.Buffer
m := v.(map[string]interface{})
buf.WriteString(fmt.Sprintf("%s-", m["origin_id"].(string)))
if v, ok := m["failover_criteria"]; ok {
if l := v.([]interface{}); len(l) > 0 {
buf.WriteString(fmt.Sprintf("%d-", failoverCriteriaHash(l[0])))
}
}
if v, ok := m["member"]; ok {
if members := v.([]interface{}); len(members) > 0 {
for _, member := range members {
buf.WriteString(fmt.Sprintf("%d-", memberHash(member)))
}
}
}
return hashcode.String(buf.String())
}

func memberHash(v interface{}) int {
var buf bytes.Buffer
buf.WriteString(fmt.Sprintf("%s-", v.(map[string]interface{})["origin_id"]))
return hashcode.String(buf.String())
}

func failoverCriteriaHash(v interface{}) int {
var buf bytes.Buffer
m := v.(map[string]interface{})
if v, ok := m["status_codes"]; ok {
for _, w := range v.(*schema.Set).List() {
buf.WriteString(fmt.Sprintf("%d-", w))
}
}
return hashcode.String(buf.String())
}

func expandCustomHeaders(s *schema.Set) *cloudfront.CustomHeaders {
qty := 0
items := []*cloudfront.OriginCustomHeader{}
Expand Down
73 changes: 73 additions & 0 deletions aws/cloudfront_distribution_configuration_structure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,32 @@ func multiOriginConf() *schema.Set {
return schema.NewSet(originHash, []interface{}{originWithCustomConf(), originWithS3Conf()})
}

func originGroupMembers() []interface{} {
return []interface{}{map[string]interface{}{
"origin_id": "S3origin",
}, map[string]interface{}{
"origin_id": "S3failover",
}}
}

func failoverStatusCodes() map[string]interface{} {
return map[string]interface{}{
"status_codes": schema.NewSet(schema.HashInt, []interface{}{503, 504}),
}
}

func originGroupConf() map[string]interface{} {
return map[string]interface{}{
"origin_id": "groupS3",
"failover_criteria": []interface{}{failoverStatusCodes()},
"member": originGroupMembers(),
}
}

func originGroupsConf() *schema.Set {
return schema.NewSet(originGroupHash, []interface{}{originGroupConf()})
}

func geoRestrictionWhitelistConf() map[string]interface{} {
return map[string]interface{}{
"restriction_type": "whitelist",
Expand Down Expand Up @@ -540,6 +566,53 @@ func TestCloudFrontStructure_flattenOrigins(t *testing.T) {
}
}

func TestCloudFrontStructure_expandOriginGroups(t *testing.T) {
in := originGroupsConf()
groups := expandOriginGroups(in)

if *groups.Quantity != 1 {
t.Fatalf("Expected origin group quantity to be %v, got %v", 1, *groups.Quantity)
}
originGroup := groups.Items[0]
if *originGroup.Id != "groupS3" {
t.Fatalf("Expected origin group id to be %v, got %v", "groupS3", *originGroup.Id)
}
if *originGroup.FailoverCriteria.StatusCodes.Quantity != 2 {
t.Fatalf("Expected 2 origin group status codes, got %v", *originGroup.FailoverCriteria.StatusCodes.Quantity)
}
statusCodes := originGroup.FailoverCriteria.StatusCodes.Items
for _, code := range statusCodes {
if *code != 503 && *code != 504 {
t.Fatalf("Expected origin group failover status code to either 503 or 504 got %v", *code)
}
}

if *originGroup.Members.Quantity > 2 {
t.Fatalf("Expected origin group member quantity to be 2, got %v", *originGroup.Members.Quantity)
}

members := originGroup.Members.Items
if len(members) > 2 {
t.Fatalf("Expected 2 origin group members, got %v", len(members))
}
for _, member := range members {
if *member.OriginId != "S3failover" && *member.OriginId != "S3origin" {
t.Fatalf("Expected origin group member to either S3failover or s3origin got %v", *member.OriginId)
}
}
}

func TestCloudFrontStructure_flattenOriginGroups(t *testing.T) {
in := originGroupsConf()
groups := expandOriginGroups(in)
out := flattenOriginGroups(groups)
diff := in.Difference(out)

if len(diff.List()) > 0 {
t.Fatalf("Expected out to be %v, got %v, diff: %v", in, out, diff)
}
}

func TestCloudFrontStructure_expandOrigin(t *testing.T) {
data := originWithCustomConf()
or := expandOrigin(data)
Expand Down
41 changes: 41 additions & 0 deletions aws/resource_aws_cloudfront_distribution.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,47 @@ func resourceAwsCloudFrontDistribution() *schema.Resource {
},
},
},
"origin_group": {
Type: schema.TypeSet,
Optional: true,
Set: originGroupHash,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"origin_id": {
Type: schema.TypeString,
Required: true,
ValidateFunc: validation.NoZeroValues,
},
"failover_criteria": {
Type: schema.TypeList,
Required: true,
MaxItems: 1,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"status_codes": {
Type: schema.TypeSet,
Required: true,
Elem: &schema.Schema{Type: schema.TypeInt},
},
},
},
},
"member": {
Type: schema.TypeList,
Required: true,
MinItems: 2,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"origin_id": {
Type: schema.TypeString,
Required: true,
},
},
},
},
},
},
},
"origin": {
Type: schema.TypeSet,
Required: true,
Expand Down
Loading