Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add test cases for update repl schedule on failover group #1578

Merged
merged 2 commits into from
Feb 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 114 additions & 78 deletions pkg/resources/failover_group.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ import (
"errors"
"fmt"
"log"
"strconv"
"strings"

"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"

"golang.org/x/exp/slices"

"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake"
)

Expand All @@ -20,6 +22,9 @@ var failoverGroupSchema = map[string]*schema.Schema{
Required: true,
ForceNew: true,
Description: "Specifies the identifier for the failover group. The identifier must start with an alphabetic character and cannot contain spaces or special characters unless the identifier string is enclosed in double quotes (e.g. \"My object\"). Identifiers enclosed in double quotes are also case-sensitive.",
DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool {
return strings.EqualFold(old, new)
},
},
"object_types": {
Type: schema.TypeSet,
Expand Down Expand Up @@ -255,8 +260,7 @@ func CreateFailoverGroup(d *schema.ResourceData, meta interface{}) error {
}

d.SetId(name)

return ReadFailoverGroup(d, meta)
return nil
}

// ReadFailoverGroup implements schema.ReadFunc.
Expand All @@ -276,92 +280,121 @@ func ReadFailoverGroup(d *schema.ResourceData, meta interface{}) error {
return fmt.Errorf("error listing failover groups err = %w", err)
}

found := false
var failoverGroup snowflake.FailoverGroup
// find the failover group we are looking for by matching the name
for _, fg := range failoverGroups {
if fg.Name.String == name && fg.AccountLocator.String == accountLocator {
found = true
if err := d.Set("name", fg.Name.String); err != nil {
if strings.EqualFold(fg.Name.String, name) && strings.EqualFold(fg.AccountLocator.String, accountLocator) {
failoverGroup = fg
}
}

if failoverGroup.Name.String == "" {
log.Printf("[DEBUG] failover group (%v) not found when listing all failover groups in account", name)
d.SetId("")
return nil
}

if err := d.Set("name", failoverGroup.Name.String); err != nil {
return err
}
// if the failover group is created from a replica, then we do not want to get the other values
if _, ok := d.GetOk("from_replica"); ok {
log.Printf("[DEBUG] failover group %v is created from a replica, rest of values are computed\n", name)
return nil
}

replicationSchedule := failoverGroup.ReplicationSchedule.String
if replicationSchedule != "" {
if strings.Contains(replicationSchedule, "MINUTE") {
interval, err := strconv.Atoi(strings.TrimSuffix(replicationSchedule, " MINUTE"))
if err != nil {
return err
}
// if the failover group is created from a replica, then we do not want to get the other values
if _, ok := d.GetOk("from_replica"); ok {
log.Printf("[DEBUG] failover group %v is created from a replica, rest of values are computed\n", name)
return nil
err = d.Set("replication_schedule", []interface{}{
map[string]interface{}{
"interval": interval,
},
})
if err != nil {
return err
}

ots := strings.Split(fg.ObjectTypes.String, ",")
var objectTypes []string
for _, v := range ots {
objectType := strings.TrimSpace(v)
if objectType == "" {
continue
}
objectTypes = append(objectTypes, objectType)
} else {
repScheduleParts := strings.Split(replicationSchedule, " ")
timeZone := repScheduleParts[len(repScheduleParts)-1]
expression := strings.TrimSuffix(strings.TrimPrefix(replicationSchedule, "USING CRON "), " "+timeZone)
err = d.Set("replication_schedule", []interface{}{
map[string]interface{}{
"cron": []interface{}{
map[string]interface{}{
"expression": expression,
"time_zone": timeZone,
},
},
},
})
if err != nil {
return err
}
}
}

// this is basically a hack to get around the fact that the API returns the object types in a different order than what is set
// this logic could also be put in the diff suppress function, but I think it is better to do it here.
currentObjectTypeList := d.Get("object_types").(*schema.Set).List()
if len(currentObjectTypeList) != len(objectTypes) {
log.Printf("[DEBUG] object types are different, current: %v, new: %v", currentObjectTypeList, objectTypes)
if err := d.Set("object_types", objectTypes); err != nil {
return err
}
}
objectTypes := helpers.SplitStringToSlice(failoverGroup.ObjectTypes.String, ",")

for _, v := range currentObjectTypeList {
if !slices.Contains(objectTypes, v.(string)) {
log.Printf("[DEBUG] object types are different, current: %v, new: %v", currentObjectTypeList, objectTypes)
if err := d.Set("object_types", objectTypes); err != nil {
return err
}
break
}
}
// this is basically a hack to get around the fact that the API returns the object types in a different order than what is set
// this logic could also be put in the diff suppress function, but I think it is better to do it here.
currentObjectTypeList := d.Get("object_types").(*schema.Set).List()
if len(currentObjectTypeList) != len(objectTypes) {
log.Printf("[DEBUG] object types are different, current: %v, new: %v", currentObjectTypeList, objectTypes)
if err := d.Set("object_types", objectTypes); err != nil {
return err
}
}

allowedIntegrationTypes := fg.AllowedIntegrationTypes.String
if allowedIntegrationTypes != "" {
aits := strings.Split(allowedIntegrationTypes, ",")
var allowedIntegrationTypes []interface{}
for _, v := range aits {
allowedIntegrationType := strings.TrimSpace(v)
if allowedIntegrationType == "" {
continue
}
if allowedIntegrationType == "SECURITY" {
allowedIntegrationType = "SECURITY INTEGRATIONS"
}
allowedIntegrationTypes = append(allowedIntegrationTypes, allowedIntegrationType)
}
allowedIntegrationTypesSet := schema.NewSet(schema.HashString, allowedIntegrationTypes)
if err := d.Set("allowed_integration_types", allowedIntegrationTypesSet); err != nil {
return err
}
for _, v := range currentObjectTypeList {
if !slices.Contains(objectTypes, v.(string)) {
log.Printf("[DEBUG] object types are different, current: %v, new: %v", currentObjectTypeList, objectTypes)
if err := d.Set("object_types", objectTypes); err != nil {
return err
}
break
}
}

allowedAccounts := fg.AllowedAccounts.String
if allowedAccounts != "" {
aa := strings.Split(allowedAccounts, ",")
var allowedAccounts []interface{}
for _, v := range aa {
allowedAccount := strings.TrimSpace(v)
if allowedAccount == "" {
continue
}
allowedAccounts = append(allowedAccounts, allowedAccount)
}
allowedAccountsSet := schema.NewSet(schema.HashString, allowedAccounts)
if err := d.Set("allowed_accounts", allowedAccountsSet); err != nil {
return err
}
allowedIntegrationTypes := failoverGroup.AllowedIntegrationTypes.String
if allowedIntegrationTypes != "" {
aits := strings.Split(allowedIntegrationTypes, ",")
var allowedIntegrationTypes []interface{}
for _, v := range aits {
allowedIntegrationType := strings.TrimSpace(v)
if allowedIntegrationType == "" {
continue
}
if allowedIntegrationType == "SECURITY" {
allowedIntegrationType = "SECURITY INTEGRATIONS"
}
allowedIntegrationTypes = append(allowedIntegrationTypes, allowedIntegrationType)
}
allowedIntegrationTypesSet := schema.NewSet(schema.HashString, allowedIntegrationTypes)
if err := d.Set("allowed_integration_types", allowedIntegrationTypesSet); err != nil {
return err
}
}

if !found {
log.Printf("[DEBUG] failover group (%v) not found when listing all failover groups in account", name)
d.SetId("")
return nil
allowedAccounts := failoverGroup.AllowedAccounts.String
if allowedAccounts != "" {
aa := strings.Split(allowedAccounts, ",")
var allowedAccounts []interface{}
for _, v := range aa {
allowedAccount := strings.TrimSpace(v)
if allowedAccount == "" {
continue
}
allowedAccounts = append(allowedAccounts, allowedAccount)
}
allowedAccountsSet := schema.NewSet(schema.HashString, allowedAccounts)
if err := d.Set("allowed_accounts", allowedAccountsSet); err != nil {
return err
}
}

allowedDatabases, err := snowflake.ShowDatabasesInFailoverGroup(name, db)
Expand Down Expand Up @@ -564,8 +597,10 @@ func UpdateFailoverGroup(d *schema.ResourceData, meta interface{}) error {
if d.HasChange("replication_schedule") {
_, new := d.GetChange("replication_schedule")
replicationSchedule := new.([]interface{})[0].(map[string]interface{})
if v, ok := replicationSchedule["cron"]; ok {
c := v.([]interface{})
log.Printf("[DEBUG] replicationSchedule: %v", replicationSchedule)
log.Printf("[DEBUG] replicationSchedule[cron]: %v", replicationSchedule["cron"])
c := replicationSchedule["cron"].([]interface{})
if len(c) > 0 {
if len(c) > 0 {
cron := c[0].(map[string]interface{})
cronExpression := cron["expression"].(string)
Expand All @@ -579,8 +614,9 @@ func UpdateFailoverGroup(d *schema.ResourceData, meta interface{}) error {
return fmt.Errorf("error updating replication cron schedule for failover group %v err = %w", name, err)
}
}
} else if v, ok := replicationSchedule["interval"]; ok {
interval := v.(int)
} else {
log.Printf("[DEBUG] replicationSchedule[interval]: %v", replicationSchedule["interval"])
interval := replicationSchedule["interval"].(int)
stmt := builder.ChangeReplicationIntervalSchedule(interval)
if err := snowflake.Exec(db, stmt); err != nil {
return fmt.Errorf("error updating replication interval schedule for failover group %v err = %w", name, err)
Expand Down
Loading