Skip to content

Commit

Permalink
fix: Fixes network_container resource update (#2055)
Browse files Browse the repository at this point in the history
* failing test

* recreate resource if provider_name changes

* fix update

* allow update in place when provider_name changes

* fix regions error handling

* refactor tests checks

* refactor checks in mig tests as well

* fix change check
  • Loading branch information
lantoli authored Mar 21, 2024
1 parent 9229ae4 commit 39ace4e
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 129 deletions.
54 changes: 29 additions & 25 deletions internal/service/networkcontainer/resource_network_container.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"log"
"net/http"
"reflect"
"strings"
"time"

Expand Down Expand Up @@ -211,42 +210,47 @@ func resourceRead(ctx context.Context, d *schema.ResourceData, meta any) diag.Di
}

func resourceUpdate(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics {
if !d.HasChange("provider_name") && !d.HasChange("atlas_cidr_block") && !d.HasChange("region_name") && !d.HasChange("region") && !d.HasChange("regions") {
return resourceRead(ctx, d, meta)
}

connV2 := meta.(*config.MongoDBClient).AtlasV2
ids := conversion.DecodeStateID(d.Id())
projectID := ids["project_id"]
containerID := ids["container_id"]

container := new(admin.CloudProviderContainer)

if d.HasChange("atlas_cidr_block") {
atlasCidrBlock := d.Get("atlas_cidr_block").(string)
providerName := d.Get("provider_name").(string)
container.AtlasCidrBlock = &atlasCidrBlock
container.ProviderName = &providerName
}

if d.HasChange("provider_name") {
providerName := d.Get("provider_name").(string)
container.ProviderName = &providerName
}

if d.HasChange("region_name") {
regionName, _ := conversion.ValRegion(d.Get("region_name"))
container.RegionName = &regionName
}
providerName := d.Get("provider_name").(string)
cidr := d.Get("atlas_cidr_block").(string)

if d.HasChange("region") {
region, _ := conversion.ValRegion(d.Get("region"))
container.Region = &region
params := &admin.CloudProviderContainer{
ProviderName: conversion.StringPtr(providerName),
AtlasCidrBlock: conversion.StringPtr(cidr),
}

if !reflect.DeepEqual(container, admin.CloudProviderContainer{}) {
_, _, err := connV2.NetworkPeeringApi.UpdatePeeringContainer(ctx, projectID, containerID, container).Execute()
switch providerName {
case constant.AWS:
regionName, err := conversion.ValRegion(d.Get("region_name"))
if err != nil {
return diag.FromErr(fmt.Errorf(errorContainerUpdate, containerID, err))
}
params.SetRegionName(regionName)
case constant.AZURE:
region, err := conversion.ValRegion(d.Get("region"))
if err != nil {
return diag.FromErr(fmt.Errorf(errorContainerUpdate, containerID, err))
}
params.SetRegion(region)
case constant.GCP:
if regionList, ok := d.GetOk("regions"); ok {
if regions := cast.ToStringSlice(regionList); regions != nil {
params.SetRegions(regions)
}
}
}
_, _, err := connV2.NetworkPeeringApi.UpdatePeeringContainer(ctx, projectID, containerID, params).Execute()
if err != nil {
return diag.FromErr(fmt.Errorf(errorContainerUpdate, containerID, err))
}

return resourceRead(ctx, d, meta)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,7 @@ func TestMigNetworkContainer_basicAWS(t *testing.T) {
{
ExternalProviders: mig.ExternalProviders(),
Config: config,
Check: resource.ComposeTestCheckFunc(
checkExists(resourceName),
resource.TestCheckResourceAttrSet(resourceName, "project_id"),
resource.TestCheckResourceAttr(resourceName, "provider_name", constant.AWS),
resource.TestCheckResourceAttrSet(resourceName, "provisioned"),
),
Check: resource.ComposeTestCheckFunc(commonChecks(constant.AWS)...),
},
mig.TestStepCheckEmptyPlan(config),
},
Expand All @@ -53,12 +48,7 @@ func TestMigNetworkContainer_basicAzure(t *testing.T) {
{
ExternalProviders: mig.ExternalProviders(),
Config: config,
Check: resource.ComposeTestCheckFunc(
checkExists(resourceName),
resource.TestCheckResourceAttrSet(resourceName, "project_id"),
resource.TestCheckResourceAttr(resourceName, "provider_name", constant.AZURE),
resource.TestCheckResourceAttrSet(resourceName, "provisioned"),
),
Check: resource.ComposeTestCheckFunc(commonChecks(constant.AZURE)...),
},
mig.TestStepCheckEmptyPlan(config),
},
Expand All @@ -80,12 +70,7 @@ func TestMigNetworkContainer_basicGCP(t *testing.T) {
{
ExternalProviders: mig.ExternalProviders(),
Config: config,
Check: resource.ComposeTestCheckFunc(
checkExists(resourceName),
resource.TestCheckResourceAttrSet(resourceName, "project_id"),
resource.TestCheckResourceAttr(resourceName, "provider_name", constant.GCP),
resource.TestCheckResourceAttrSet(resourceName, "provisioned"),
),
Check: resource.ComposeTestCheckFunc(commonChecks(constant.GCP)...),
},
mig.TestStepCheckEmptyPlan(config),
},
Expand Down
148 changes: 62 additions & 86 deletions internal/service/networkcontainer/resource_network_container_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func TestAccNetworkContainer_basicAWS(t *testing.T) {
projectID = acc.ProjectIDExecution(t)
randInt = acctest.RandIntRange(0, 255)
cidrBlock = fmt.Sprintf("10.8.%d.0/24", randInt)
randIntUpdated = acctest.RandIntRange(0, 255)
randIntUpdated = (randInt + 1) % 256
cidrBlockUpdated = fmt.Sprintf("10.8.%d.0/24", randIntUpdated)
)

Expand All @@ -35,31 +35,11 @@ func TestAccNetworkContainer_basicAWS(t *testing.T) {
Steps: []resource.TestStep{
{
Config: configBasic(projectID, cidrBlock, constant.AWS, "US_EAST_1"),
Check: resource.ComposeTestCheckFunc(
checkExists(resourceName),
resource.TestCheckResourceAttrSet(resourceName, "project_id"),
resource.TestCheckResourceAttr(resourceName, "provider_name", constant.AWS),
resource.TestCheckResourceAttrSet(resourceName, "provisioned"),

resource.TestCheckResourceAttrSet(dataSourceName, "project_id"),
resource.TestCheckResourceAttr(dataSourceName, "provider_name", constant.AWS),
resource.TestCheckResourceAttrSet(dataSourceName, "provisioned"),

resource.TestCheckResourceAttrWith(dataSourcePluralName, "results.#", acc.IntGreatThan(0)),
resource.TestCheckResourceAttrSet(dataSourcePluralName, "results.0.id"),
resource.TestCheckResourceAttrSet(dataSourcePluralName, "results.0.atlas_cidr_block"),
resource.TestCheckResourceAttrSet(dataSourcePluralName, "results.0.provider_name"),
resource.TestCheckResourceAttrSet(dataSourcePluralName, "results.0.provisioned"),
),
Check: resource.ComposeTestCheckFunc(commonChecks(constant.AWS)...),
},
{
Config: configBasic(projectID, cidrBlockUpdated, constant.AWS, "US_WEST_2"),
Check: resource.ComposeTestCheckFunc(
checkExists(resourceName),
resource.TestCheckResourceAttrSet(resourceName, "project_id"),
resource.TestCheckResourceAttr(resourceName, "provider_name", constant.AWS),
resource.TestCheckResourceAttrSet(resourceName, "provisioned"),
),
Config: configBasic(projectID, cidrBlockUpdated, constant.AWS, "US_EAST_2"),
Check: resource.ComposeTestCheckFunc(commonChecks(constant.AWS)...),
},
},
})
Expand All @@ -69,7 +49,7 @@ func TestAccNetworkContainer_basicAzure(t *testing.T) {
var (
randInt = acctest.RandIntRange(0, 255)
cidrBlock = fmt.Sprintf("10.8.%d.0/24", randInt)
randIntUpdated = acctest.RandIntRange(0, 255)
randIntUpdated = (randInt + 1) % 256
cidrBlockUpdated = fmt.Sprintf("10.8.%d.0/24", randIntUpdated)
projectID = acc.ProjectIDExecution(t)
)
Expand All @@ -81,31 +61,11 @@ func TestAccNetworkContainer_basicAzure(t *testing.T) {
Steps: []resource.TestStep{
{
Config: configBasic(projectID, cidrBlock, constant.AZURE, "US_EAST_2"),
Check: resource.ComposeTestCheckFunc(
checkExists(resourceName),
resource.TestCheckResourceAttrSet(resourceName, "project_id"),
resource.TestCheckResourceAttr(resourceName, "provider_name", constant.AZURE),
resource.TestCheckResourceAttrSet(resourceName, "provisioned"),

resource.TestCheckResourceAttrSet(dataSourceName, "project_id"),
resource.TestCheckResourceAttr(dataSourceName, "provider_name", constant.AZURE),
resource.TestCheckResourceAttrSet(dataSourceName, "provisioned"),

resource.TestCheckResourceAttrWith(dataSourcePluralName, "results.#", acc.IntGreatThan(0)),
resource.TestCheckResourceAttrSet(dataSourcePluralName, "results.0.id"),
resource.TestCheckResourceAttrSet(dataSourcePluralName, "results.0.atlas_cidr_block"),
resource.TestCheckResourceAttrSet(dataSourcePluralName, "results.0.provider_name"),
resource.TestCheckResourceAttrSet(dataSourcePluralName, "results.0.provisioned"),
),
Check: resource.ComposeTestCheckFunc(commonChecks(constant.AZURE)...),
},
{
Config: configBasic(projectID, cidrBlockUpdated, constant.AZURE, "US_EAST_2"),
Check: resource.ComposeTestCheckFunc(
checkExists(resourceName),
resource.TestCheckResourceAttrSet(resourceName, "project_id"),
resource.TestCheckResourceAttr(resourceName, "provider_name", constant.AZURE),
resource.TestCheckResourceAttrSet(resourceName, "provisioned"),
),
Check: resource.ComposeTestCheckFunc(commonChecks(constant.AZURE)...),
},
},
})
Expand All @@ -115,7 +75,7 @@ func TestAccNetworkContainer_basicGCP(t *testing.T) {
var (
randInt = acctest.RandIntRange(0, 255)
gcpCidrBlock = fmt.Sprintf("10.%d.0.0/18", randInt)
randIntUpdated = acctest.RandIntRange(0, 255)
randIntUpdated = (randInt + 1) % 256
cidrBlockUpdated = fmt.Sprintf("10.%d.0.0/18", randIntUpdated)
projectID = acc.ProjectIDExecution(t)
)
Expand All @@ -127,31 +87,11 @@ func TestAccNetworkContainer_basicGCP(t *testing.T) {
Steps: []resource.TestStep{
{
Config: configBasic(projectID, gcpCidrBlock, constant.GCP, ""),
Check: resource.ComposeTestCheckFunc(
checkExists(resourceName),
resource.TestCheckResourceAttrSet(resourceName, "project_id"),
resource.TestCheckResourceAttr(resourceName, "provider_name", constant.GCP),
resource.TestCheckResourceAttrSet(resourceName, "provisioned"),

resource.TestCheckResourceAttrSet(dataSourceName, "project_id"),
resource.TestCheckResourceAttr(dataSourceName, "provider_name", constant.GCP),
resource.TestCheckResourceAttrSet(dataSourceName, "provisioned"),

resource.TestCheckResourceAttrWith(dataSourcePluralName, "results.#", acc.IntGreatThan(0)),
resource.TestCheckResourceAttrSet(dataSourcePluralName, "results.0.id"),
resource.TestCheckResourceAttrSet(dataSourcePluralName, "results.0.atlas_cidr_block"),
resource.TestCheckResourceAttrSet(dataSourcePluralName, "results.0.provider_name"),
resource.TestCheckResourceAttrSet(dataSourcePluralName, "results.0.provisioned"),
),
Check: resource.ComposeTestCheckFunc(commonChecks(constant.GCP)...),
},
{
Config: configBasic(projectID, cidrBlockUpdated, constant.GCP, ""),
Check: resource.ComposeTestCheckFunc(
checkExists(resourceName),
resource.TestCheckResourceAttrSet(resourceName, "project_id"),
resource.TestCheckResourceAttr(resourceName, "provider_name", constant.GCP),
resource.TestCheckResourceAttrSet(resourceName, "provisioned"),
),
Check: resource.ComposeTestCheckFunc(commonChecks(constant.GCP)...),
},
},
})
Expand All @@ -172,22 +112,7 @@ func TestAccNetworkContainer_withRegionsGCP(t *testing.T) {
Steps: []resource.TestStep{
{
Config: configBasic(projectID, gcpWithRegionsCidrBlock, constant.GCP, regions),
Check: resource.ComposeTestCheckFunc(
checkExists(resourceName),
resource.TestCheckResourceAttrSet(resourceName, "project_id"),
resource.TestCheckResourceAttr(resourceName, "provider_name", constant.GCP),
resource.TestCheckResourceAttrSet(resourceName, "provisioned"),

resource.TestCheckResourceAttrSet(dataSourceName, "project_id"),
resource.TestCheckResourceAttr(dataSourceName, "provider_name", constant.GCP),
resource.TestCheckResourceAttrSet(dataSourceName, "provisioned"),

resource.TestCheckResourceAttrWith(dataSourcePluralName, "results.#", acc.IntGreatThan(0)),
resource.TestCheckResourceAttrSet(dataSourcePluralName, "results.0.id"),
resource.TestCheckResourceAttrSet(dataSourcePluralName, "results.0.atlas_cidr_block"),
resource.TestCheckResourceAttrSet(dataSourcePluralName, "results.0.provider_name"),
resource.TestCheckResourceAttrSet(dataSourcePluralName, "results.0.provisioned"),
),
Check: resource.ComposeTestCheckFunc(commonChecks(constant.GCP)...),
},
},
})
Expand Down Expand Up @@ -218,6 +143,38 @@ func TestAccNetworkContainer_importBasic(t *testing.T) {
})
}

func TestAccNetworkContainer_updateIndividualFields(t *testing.T) {
var (
projectID = acc.ProjectIDExecution(t)
randInt = acctest.RandIntRange(0, 255)
cidrBlock = fmt.Sprintf("10.8.%d.0/24", randInt)
randIntUpdated = (randInt + 1) % 256
cidrBlockUpdated = fmt.Sprintf("10.8.%d.0/24", randIntUpdated)
region = "EU_WEST_1"
regionUpdated = "EU_WEST_2"
)

resource.ParallelTest(t, resource.TestCase{
PreCheck: func() { acc.PreCheckBasic(t) },
ProtoV6ProviderFactories: acc.TestAccProviderV6Factories,
CheckDestroy: checkDestroy,
Steps: []resource.TestStep{
{
Config: configBasic(projectID, cidrBlock, constant.AWS, region),
Check: resource.ComposeTestCheckFunc(commonChecks(constant.AWS)...),
},
{
Config: configBasic(projectID, cidrBlockUpdated, constant.AWS, region),
Check: resource.ComposeTestCheckFunc(commonChecks(constant.AWS)...),
},
{
Config: configBasic(projectID, cidrBlockUpdated, constant.AWS, regionUpdated),
Check: resource.ComposeTestCheckFunc(commonChecks(constant.AWS)...),
},
},
})
}

func importStateIDFunc(resourceName string) resource.ImportStateIdFunc {
return func(s *terraform.State) (string, error) {
rs, ok := s.RootModule().Resources[resourceName]
Expand Down Expand Up @@ -261,6 +218,25 @@ func checkDestroy(s *terraform.State) error {
return nil
}

func commonChecks(providerName string) []resource.TestCheckFunc {
return []resource.TestCheckFunc{
checkExists(resourceName),
resource.TestCheckResourceAttrSet(resourceName, "project_id"),
resource.TestCheckResourceAttr(resourceName, "provider_name", providerName),
resource.TestCheckResourceAttrSet(resourceName, "provisioned"),

resource.TestCheckResourceAttrSet(dataSourceName, "project_id"),
resource.TestCheckResourceAttr(dataSourceName, "provider_name", providerName),
resource.TestCheckResourceAttrSet(dataSourceName, "provisioned"),

resource.TestCheckResourceAttrWith(dataSourcePluralName, "results.#", acc.IntGreatThan(0)),
resource.TestCheckResourceAttrSet(dataSourcePluralName, "results.0.id"),
resource.TestCheckResourceAttrSet(dataSourcePluralName, "results.0.atlas_cidr_block"),
resource.TestCheckResourceAttrSet(dataSourcePluralName, "results.0.provider_name"),
resource.TestCheckResourceAttrSet(dataSourcePluralName, "results.0.provisioned"),
}
}

func configBasic(projectID, cidrBlock, providerName, region string) string {
var regionStr string
if region != "" {
Expand Down

0 comments on commit 39ace4e

Please sign in to comment.