diff --git a/azurerm/internal/services/policy/policy_remediation_resource.go b/azurerm/internal/services/policy/policy_remediation_resource.go index 1508d5d56986..869b0e7692d3 100644 --- a/azurerm/internal/services/policy/policy_remediation_resource.go +++ b/azurerm/internal/services/policy/policy_remediation_resource.go @@ -7,7 +7,9 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/services/preview/policyinsights/mgmt/2019-10-01-preview/policyinsights" + "github.com/hashicorp/terraform-plugin-sdk/helper/resource" "github.com/hashicorp/terraform-plugin-sdk/helper/schema" + "github.com/hashicorp/terraform-plugin-sdk/helper/validation" "github.com/terraform-providers/terraform-provider-azurerm/azurerm/helpers/suppress" "github.com/terraform-providers/terraform-provider-azurerm/azurerm/helpers/tf" "github.com/terraform-providers/terraform-provider-azurerm/azurerm/internal/clients" @@ -77,6 +79,16 @@ func resourceArmPolicyRemediation() *schema.Resource { // TODO: remove this suppression when github issue https://github.com/Azure/azure-rest-api-specs/issues/8353 is addressed DiffSuppressFunc: suppress.CaseDifference, }, + + "resource_discovery_mode": { + Type: schema.TypeString, + Optional: true, + Default: string(policyinsights.ExistingNonCompliant), + ValidateFunc: validation.StringInSlice([]string{ + string(policyinsights.ExistingNonCompliant), + string(policyinsights.ReEvaluateCompliance), + }, false), + }, }, } } @@ -111,6 +123,7 @@ func resourceArmPolicyRemediationCreateUpdate(d *schema.ResourceData, meta inter }, PolicyAssignmentID: utils.String(d.Get("policy_assignment_id").(string)), PolicyDefinitionReferenceID: utils.String(d.Get("policy_definition_reference_id").(string)), + ResourceDiscoveryMode: policyinsights.ResourceDiscoveryMode(d.Get("resource_discovery_mode").(string)), }, } @@ -177,6 +190,7 @@ func resourceArmPolicyRemediationRead(d *schema.ResourceData, meta interface{}) d.Set("policy_assignment_id", props.PolicyAssignmentID) d.Set("policy_definition_reference_id", props.PolicyDefinitionReferenceID) + d.Set("resource_discovery_mode", string(props.ResourceDiscoveryMode)) } return nil @@ -192,6 +206,38 @@ func resourceArmPolicyRemediationDelete(d *schema.ResourceData, meta interface{} return err } + // we have to cancel the remediation first before deleting it when the resource_discovery_mode is set to ReEvaluateCompliance + // therefore we first retrieve the remediation to see if the resource_discovery_mode is switched to ReEvaluateCompliance + existing, err := RemediationGetAtScope(ctx, client, id.Name, id.PolicyScopeId) + if err != nil { + if utils.ResponseWasNotFound(existing.Response) { + return nil + } + return fmt.Errorf("retrieving Policy Remediation %q (Scope %q): %+v", id.Name, id.ScopeId(), err) + } + + if existing.RemediationProperties != nil && existing.RemediationProperties.ResourceDiscoveryMode == policyinsights.ReEvaluateCompliance { + log.Printf("[DEBUG] cancelling the remediation first before deleting it when `resource_discovery_mode` is set to `ReEvaluateCompliance`") + if err := cancelRemediation(ctx, client, id.Name, id.PolicyScopeId); err != nil { + return fmt.Errorf("cancelling Policy Remediation %q (Scope %q): %+v", id.Name, id.ScopeId(), err) + } + + log.Printf("[DEBUG] waiting for the Policy Remediation %q (Scope %q) to be canceled", id.Name, id.ScopeId()) + stateConf := &resource.StateChangeConf{ + Pending: []string{"Cancelling"}, + Target: []string{ + "Succeeded", "Canceled", "Failed", + }, + Refresh: policyRemediationCancellationRefreshFunc(ctx, client, id.Name, id.PolicyScopeId), + MinTimeout: 10 * time.Second, + Timeout: d.Timeout(schema.TimeoutDelete), + } + + if _, err := stateConf.WaitForState(); err != nil { + return fmt.Errorf("waiting for Policy Remediation %q to be canceled: %+v", id.Name, err) + } + } + switch scope := id.PolicyScopeId.(type) { case parse.ScopeAtSubscription: _, err = client.DeleteAtSubscription(ctx, scope.SubscriptionId, id.Name) @@ -211,6 +257,42 @@ func resourceArmPolicyRemediationDelete(d *schema.ResourceData, meta interface{} return nil } +func cancelRemediation(ctx context.Context, client *policyinsights.RemediationsClient, name string, scopeId parse.PolicyScopeId) error { + switch scopeId := scopeId.(type) { + case parse.ScopeAtSubscription: + _, err := client.CancelAtSubscription(ctx, scopeId.SubscriptionId, name) + return err + case parse.ScopeAtResourceGroup: + _, err := client.CancelAtResourceGroup(ctx, scopeId.SubscriptionId, scopeId.ResourceGroup, name) + return err + case parse.ScopeAtResource: + _, err := client.CancelAtResource(ctx, scopeId.ScopeId(), name) + return err + case parse.ScopeAtManagementGroup: + _, err := client.CancelAtManagementGroup(ctx, scopeId.ManagementGroupName, name) + return err + default: + return fmt.Errorf("nvalid scope type") + } +} + +func policyRemediationCancellationRefreshFunc(ctx context.Context, client *policyinsights.RemediationsClient, name string, scopeId parse.PolicyScopeId) resource.StateRefreshFunc { + return func() (interface{}, string, error) { + resp, err := RemediationGetAtScope(ctx, client, name, scopeId) + if err != nil { + return nil, "", fmt.Errorf("issuing read request in policyRemediationCancellationRefreshFunc for Policy Remediation %q (Scope %q): %+v", name, scopeId.ScopeId(), err) + } + + if resp.RemediationProperties == nil { + return nil, "", fmt.Errorf("`properties` was nil") + } + if resp.RemediationProperties.ProvisioningState == nil { + return nil, "", fmt.Errorf("`properties.ProvisioningState` was nil") + } + return resp, *resp.RemediationProperties.ProvisioningState, nil + } +} + // RemediationGetAtScope is a wrapper of the 4 Get functions on RemediationsClient, combining them into one to simplify code. func RemediationGetAtScope(ctx context.Context, client *policyinsights.RemediationsClient, name string, scopeId parse.PolicyScopeId) (policyinsights.Remediation, error) { switch scopeId := scopeId.(type) { diff --git a/azurerm/internal/services/policy/tests/policy_remediation_resource_test.go b/azurerm/internal/services/policy/tests/policy_remediation_resource_test.go index 436eea38b681..f8eeb1bee594 100644 --- a/azurerm/internal/services/policy/tests/policy_remediation_resource_test.go +++ b/azurerm/internal/services/policy/tests/policy_remediation_resource_test.go @@ -25,8 +25,6 @@ func TestAccAzureRMPolicyRemediation_atSubscription(t *testing.T) { Config: testAccAzureRMPolicyRemediation_atSubscription(data), Check: resource.ComposeTestCheckFunc( testCheckAzureRMPolicyRemediationExists(data.ResourceName), - resource.TestCheckResourceAttrSet(data.ResourceName, "scope"), - resource.TestCheckResourceAttrSet(data.ResourceName, "policy_assignment_id"), ), }, data.ImportStep(), @@ -47,8 +45,6 @@ func TestAccAzureRMPolicyRemediation_atSubscriptionWithDefinitionSet(t *testing. Check: resource.ComposeTestCheckFunc( testCheckAzureRMPolicyRemediationExists(data.ResourceName), resource.TestCheckResourceAttrSet(data.ResourceName, "scope"), - resource.TestCheckResourceAttrSet(data.ResourceName, "policy_assignment_id"), - resource.TestCheckResourceAttrSet(data.ResourceName, "policy_definition_reference_id"), ), }, data.ImportStep(), @@ -68,8 +64,25 @@ func TestAccAzureRMPolicyRemediation_atResourceGroup(t *testing.T) { Config: testAccAzureRMPolicyRemediation_atResourceGroup(data), Check: resource.ComposeTestCheckFunc( testCheckAzureRMPolicyRemediationExists(data.ResourceName), - resource.TestCheckResourceAttrSet(data.ResourceName, "scope"), - resource.TestCheckResourceAttrSet(data.ResourceName, "policy_assignment_id"), + ), + }, + data.ImportStep(), + }, + }) +} + +func TestAccAzureRMPolicyRemediation_atResourceGroupWithDiscoveryMode(t *testing.T) { + data := acceptance.BuildTestData(t, "azurerm_policy_remediation", "test") + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { acceptance.PreCheck(t) }, + Providers: acceptance.SupportedProviders, + CheckDestroy: testCheckAzureRMPolicyRemediationDestroy, + Steps: []resource.TestStep{ + { + Config: testAccAzureRMPolicyRemediation_atResourceGroupWithDiscoveryMode(data), + Check: resource.ComposeTestCheckFunc( + testCheckAzureRMPolicyRemediationExists(data.ResourceName), ), }, data.ImportStep(), @@ -89,8 +102,6 @@ func TestAccAzureRMPolicyRemediation_atManagementGroup(t *testing.T) { Config: testAccAzureRMPolicyRemediation_atManagementGroup(data), Check: resource.ComposeTestCheckFunc( testCheckAzureRMPolicyRemediationExists(data.ResourceName), - resource.TestCheckResourceAttrSet(data.ResourceName, "scope"), - resource.TestCheckResourceAttrSet(data.ResourceName, "policy_assignment_id"), ), }, data.ImportStep(), @@ -110,8 +121,6 @@ func TestAccAzureRMPolicyRemediation_atResource(t *testing.T) { Config: testAccAzureRMPolicyRemediation_atResource(data), Check: resource.ComposeTestCheckFunc( testCheckAzureRMPolicyRemediationExists(data.ResourceName), - resource.TestCheckResourceAttrSet(data.ResourceName, "scope"), - resource.TestCheckResourceAttrSet(data.ResourceName, "policy_assignment_id"), ), }, data.ImportStep(), @@ -131,14 +140,13 @@ func TestAccAzureRMPolicyRemediation_updateLocation(t *testing.T) { Config: testAccAzureRMPolicyRemediation_atResourceGroup(data), Check: resource.ComposeTestCheckFunc( testCheckAzureRMPolicyRemediationExists(data.ResourceName), - resource.TestCheckResourceAttr(data.ResourceName, "location_filters.#", "0"), ), }, + data.ImportStep(), { Config: testAccAzureRMPolicyRemediation_updateLocation(data), Check: resource.ComposeTestCheckFunc( testCheckAzureRMPolicyRemediationExists(data.ResourceName), - resource.TestCheckResourceAttr(data.ResourceName, "location_filters.#", "1"), ), }, data.ImportStep(), @@ -451,6 +459,77 @@ resource "azurerm_policy_remediation" "test" { `, data.RandomString, data.Locations.Primary) } +func testAccAzureRMPolicyRemediation_atResourceGroupWithDiscoveryMode(data acceptance.TestData) string { + return fmt.Sprintf(` +provider "azurerm" { + features {} +} + +resource "azurerm_resource_group" "test" { + name = "acctestRG-policy-%[1]s" + location = "%[2]s" +} + +resource "azurerm_policy_definition" "test" { + name = "acctestDef-%[1]s" + policy_type = "Custom" + mode = "All" + display_name = "my-policy-definition" + + policy_rule = <