Skip to content

Commit

Permalink
fix: add test cases for update repl schedule on failover group (#1578)
Browse files Browse the repository at this point in the history
* add test cases for update repl schedule

* linting
  • Loading branch information
sfc-gh-swinkler authored Feb 27, 2023
1 parent 368dc8f commit ab638f0
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 95 deletions.
192 changes: 114 additions & 78 deletions pkg/resources/failover_group.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ import (
"errors"
"fmt"
"log"
"strconv"
"strings"

"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"

"golang.org/x/exp/slices"

"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake"
)

Expand All @@ -20,6 +22,9 @@ var failoverGroupSchema = map[string]*schema.Schema{
Required: true,
ForceNew: true,
Description: "Specifies the identifier for the failover group. The identifier must start with an alphabetic character and cannot contain spaces or special characters unless the identifier string is enclosed in double quotes (e.g. \"My object\"). Identifiers enclosed in double quotes are also case-sensitive.",
DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool {
return strings.EqualFold(old, new)
},
},
"object_types": {
Type: schema.TypeSet,
Expand Down Expand Up @@ -255,8 +260,7 @@ func CreateFailoverGroup(d *schema.ResourceData, meta interface{}) error {
}

d.SetId(name)

return ReadFailoverGroup(d, meta)
return nil
}

// ReadFailoverGroup implements schema.ReadFunc.
Expand All @@ -276,92 +280,121 @@ func ReadFailoverGroup(d *schema.ResourceData, meta interface{}) error {
return fmt.Errorf("error listing failover groups err = %w", err)
}

found := false
var failoverGroup snowflake.FailoverGroup
// find the failover group we are looking for by matching the name
for _, fg := range failoverGroups {
if fg.Name.String == name && fg.AccountLocator.String == accountLocator {
found = true
if err := d.Set("name", fg.Name.String); err != nil {
if strings.EqualFold(fg.Name.String, name) && strings.EqualFold(fg.AccountLocator.String, accountLocator) {
failoverGroup = fg
}
}

if failoverGroup.Name.String == "" {
log.Printf("[DEBUG] failover group (%v) not found when listing all failover groups in account", name)
d.SetId("")
return nil
}

if err := d.Set("name", failoverGroup.Name.String); err != nil {
return err
}
// if the failover group is created from a replica, then we do not want to get the other values
if _, ok := d.GetOk("from_replica"); ok {
log.Printf("[DEBUG] failover group %v is created from a replica, rest of values are computed\n", name)
return nil
}

replicationSchedule := failoverGroup.ReplicationSchedule.String
if replicationSchedule != "" {
if strings.Contains(replicationSchedule, "MINUTE") {
interval, err := strconv.Atoi(strings.TrimSuffix(replicationSchedule, " MINUTE"))
if err != nil {
return err
}
// if the failover group is created from a replica, then we do not want to get the other values
if _, ok := d.GetOk("from_replica"); ok {
log.Printf("[DEBUG] failover group %v is created from a replica, rest of values are computed\n", name)
return nil
err = d.Set("replication_schedule", []interface{}{
map[string]interface{}{
"interval": interval,
},
})
if err != nil {
return err
}

ots := strings.Split(fg.ObjectTypes.String, ",")
var objectTypes []string
for _, v := range ots {
objectType := strings.TrimSpace(v)
if objectType == "" {
continue
}
objectTypes = append(objectTypes, objectType)
} else {
repScheduleParts := strings.Split(replicationSchedule, " ")
timeZone := repScheduleParts[len(repScheduleParts)-1]
expression := strings.TrimSuffix(strings.TrimPrefix(replicationSchedule, "USING CRON "), " "+timeZone)
err = d.Set("replication_schedule", []interface{}{
map[string]interface{}{
"cron": []interface{}{
map[string]interface{}{
"expression": expression,
"time_zone": timeZone,
},
},
},
})
if err != nil {
return err
}
}
}

// this is basically a hack to get around the fact that the API returns the object types in a different order than what is set
// this logic could also be put in the diff suppress function, but I think it is better to do it here.
currentObjectTypeList := d.Get("object_types").(*schema.Set).List()
if len(currentObjectTypeList) != len(objectTypes) {
log.Printf("[DEBUG] object types are different, current: %v, new: %v", currentObjectTypeList, objectTypes)
if err := d.Set("object_types", objectTypes); err != nil {
return err
}
}
objectTypes := helpers.SplitStringToSlice(failoverGroup.ObjectTypes.String, ",")

for _, v := range currentObjectTypeList {
if !slices.Contains(objectTypes, v.(string)) {
log.Printf("[DEBUG] object types are different, current: %v, new: %v", currentObjectTypeList, objectTypes)
if err := d.Set("object_types", objectTypes); err != nil {
return err
}
break
}
}
// this is basically a hack to get around the fact that the API returns the object types in a different order than what is set
// this logic could also be put in the diff suppress function, but I think it is better to do it here.
currentObjectTypeList := d.Get("object_types").(*schema.Set).List()
if len(currentObjectTypeList) != len(objectTypes) {
log.Printf("[DEBUG] object types are different, current: %v, new: %v", currentObjectTypeList, objectTypes)
if err := d.Set("object_types", objectTypes); err != nil {
return err
}
}

allowedIntegrationTypes := fg.AllowedIntegrationTypes.String
if allowedIntegrationTypes != "" {
aits := strings.Split(allowedIntegrationTypes, ",")
var allowedIntegrationTypes []interface{}
for _, v := range aits {
allowedIntegrationType := strings.TrimSpace(v)
if allowedIntegrationType == "" {
continue
}
if allowedIntegrationType == "SECURITY" {
allowedIntegrationType = "SECURITY INTEGRATIONS"
}
allowedIntegrationTypes = append(allowedIntegrationTypes, allowedIntegrationType)
}
allowedIntegrationTypesSet := schema.NewSet(schema.HashString, allowedIntegrationTypes)
if err := d.Set("allowed_integration_types", allowedIntegrationTypesSet); err != nil {
return err
}
for _, v := range currentObjectTypeList {
if !slices.Contains(objectTypes, v.(string)) {
log.Printf("[DEBUG] object types are different, current: %v, new: %v", currentObjectTypeList, objectTypes)
if err := d.Set("object_types", objectTypes); err != nil {
return err
}
break
}
}

allowedAccounts := fg.AllowedAccounts.String
if allowedAccounts != "" {
aa := strings.Split(allowedAccounts, ",")
var allowedAccounts []interface{}
for _, v := range aa {
allowedAccount := strings.TrimSpace(v)
if allowedAccount == "" {
continue
}
allowedAccounts = append(allowedAccounts, allowedAccount)
}
allowedAccountsSet := schema.NewSet(schema.HashString, allowedAccounts)
if err := d.Set("allowed_accounts", allowedAccountsSet); err != nil {
return err
}
allowedIntegrationTypes := failoverGroup.AllowedIntegrationTypes.String
if allowedIntegrationTypes != "" {
aits := strings.Split(allowedIntegrationTypes, ",")
var allowedIntegrationTypes []interface{}
for _, v := range aits {
allowedIntegrationType := strings.TrimSpace(v)
if allowedIntegrationType == "" {
continue
}
if allowedIntegrationType == "SECURITY" {
allowedIntegrationType = "SECURITY INTEGRATIONS"
}
allowedIntegrationTypes = append(allowedIntegrationTypes, allowedIntegrationType)
}
allowedIntegrationTypesSet := schema.NewSet(schema.HashString, allowedIntegrationTypes)
if err := d.Set("allowed_integration_types", allowedIntegrationTypesSet); err != nil {
return err
}
}

if !found {
log.Printf("[DEBUG] failover group (%v) not found when listing all failover groups in account", name)
d.SetId("")
return nil
allowedAccounts := failoverGroup.AllowedAccounts.String
if allowedAccounts != "" {
aa := strings.Split(allowedAccounts, ",")
var allowedAccounts []interface{}
for _, v := range aa {
allowedAccount := strings.TrimSpace(v)
if allowedAccount == "" {
continue
}
allowedAccounts = append(allowedAccounts, allowedAccount)
}
allowedAccountsSet := schema.NewSet(schema.HashString, allowedAccounts)
if err := d.Set("allowed_accounts", allowedAccountsSet); err != nil {
return err
}
}

allowedDatabases, err := snowflake.ShowDatabasesInFailoverGroup(name, db)
Expand Down Expand Up @@ -564,8 +597,10 @@ func UpdateFailoverGroup(d *schema.ResourceData, meta interface{}) error {
if d.HasChange("replication_schedule") {
_, new := d.GetChange("replication_schedule")
replicationSchedule := new.([]interface{})[0].(map[string]interface{})
if v, ok := replicationSchedule["cron"]; ok {
c := v.([]interface{})
log.Printf("[DEBUG] replicationSchedule: %v", replicationSchedule)
log.Printf("[DEBUG] replicationSchedule[cron]: %v", replicationSchedule["cron"])
c := replicationSchedule["cron"].([]interface{})
if len(c) > 0 {
if len(c) > 0 {
cron := c[0].(map[string]interface{})
cronExpression := cron["expression"].(string)
Expand All @@ -579,8 +614,9 @@ func UpdateFailoverGroup(d *schema.ResourceData, meta interface{}) error {
return fmt.Errorf("error updating replication cron schedule for failover group %v err = %w", name, err)
}
}
} else if v, ok := replicationSchedule["interval"]; ok {
interval := v.(int)
} else {
log.Printf("[DEBUG] replicationSchedule[interval]: %v", replicationSchedule["interval"])
interval := replicationSchedule["interval"].(int)
stmt := builder.ChangeReplicationIntervalSchedule(interval)
if err := snowflake.Exec(db, stmt); err != nil {
return fmt.Errorf("error updating replication interval schedule for failover group %v err = %w", name, err)
Expand Down
Loading

0 comments on commit ab638f0

Please sign in to comment.