Skip to content

Commit

Permalink
Merge pull request #38574 from madhavvishnubhatta/f-sfn-kms-integration
Browse files Browse the repository at this point in the history
Added support for "EncryptionConfiguration" in StateMachine and Activity Resources
  • Loading branch information
ewbankkit authored Jul 29, 2024
2 parents f824dab + 1604008 commit cbd253d
Show file tree
Hide file tree
Showing 8 changed files with 588 additions and 12 deletions.
7 changes: 7 additions & 0 deletions .changelog/38574.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
```release-note:enhancement
resource/aws_sfn_activity: Add `encryption_configuration` configuration block to support the use of Customer Managed Keys with AWS KMS to encrypt Step Functions Activity resources
```

```release-note:enhancement
resource/aws_sfn_state_machine: Add `encryption_configuration` configuration block to support the use of Customer Managed Keys with AWS KMS to encrypt Step Functions State Machine resources
```
37 changes: 37 additions & 0 deletions internal/service/sfn/activity.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation"
"github.com/hashicorp/terraform-provider-aws/internal/conns"
"github.com/hashicorp/terraform-provider-aws/internal/enum"
"github.com/hashicorp/terraform-provider-aws/internal/errs"
"github.com/hashicorp/terraform-provider-aws/internal/errs/sdkdiag"
tftags "github.com/hashicorp/terraform-provider-aws/internal/tags"
Expand All @@ -42,6 +43,31 @@ func resourceActivity() *schema.Resource {
Type: schema.TypeString,
Computed: true,
},
names.AttrEncryptionConfiguration: {
Type: schema.TypeList,
Optional: true,
Computed: true,
MaxItems: 1,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"kms_data_key_reuse_period_seconds": {
Type: schema.TypeInt,
Optional: true,
ValidateFunc: validation.IntBetween(60, 900),
},
names.AttrKMSKeyID: {
Type: schema.TypeString,
Optional: true,
},
names.AttrType: {
Type: schema.TypeString,
Optional: true,
ValidateDiagFunc: enum.Validate[awstypes.EncryptionType](),
},
},
},
DiffSuppressFunc: verify.SuppressMissingOptionalConfigurationBlock,
},
names.AttrName: {
Type: schema.TypeString,
Required: true,
Expand All @@ -66,6 +92,10 @@ func resourceActivityCreate(ctx context.Context, d *schema.ResourceData, meta in
Tags: getTagsIn(ctx),
}

if v, ok := d.GetOk(names.AttrEncryptionConfiguration); ok && len(v.([]interface{})) > 0 && v.([]interface{})[0] != nil {
input.EncryptionConfiguration = expandEncryptionConfiguration(v.([]interface{})[0].(map[string]interface{}))
}

output, err := conn.CreateActivity(ctx, input)

if err != nil {
Expand Down Expand Up @@ -94,6 +124,13 @@ func resourceActivityRead(ctx context.Context, d *schema.ResourceData, meta inte
}

d.Set(names.AttrCreationDate, output.CreationDate.Format(time.RFC3339))
if output.EncryptionConfiguration != nil {
if err := d.Set(names.AttrEncryptionConfiguration, []interface{}{flattenEncryptionConfiguration(output.EncryptionConfiguration)}); err != nil {
return sdkdiag.AppendErrorf(diags, "setting encryption_configuration: %s", err)
}
} else {
d.Set(names.AttrEncryptionConfiguration, nil)
}
d.Set(names.AttrName, output.Name)

return diags
Expand Down
93 changes: 93 additions & 0 deletions internal/service/sfn/activity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ package sfn_test
import (
"context"
"fmt"
"strconv"
"testing"
"time"

awstypes "github.com/aws/aws-sdk-go-v2/service/sfn/types"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/retry"
sdkacctest "github.com/hashicorp/terraform-plugin-testing/helper/acctest"
"github.com/hashicorp/terraform-plugin-testing/helper/resource"
Expand Down Expand Up @@ -117,6 +119,67 @@ func TestAccSFNActivity_tags(t *testing.T) {
})
}

func TestAccSFNActivity_encryptionConfigurationCustomerManagedKMSKey(t *testing.T) {
ctx := acctest.Context(t)
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
resourceName := "aws_sfn_activity.test"
reusePeriodSeconds := 900
kmsKeyResource := "aws_kms_key.kms_key_for_sfn"

resource.Test(t, resource.TestCase{
PreCheck: func() { acctest.PreCheck(ctx, t) },
ErrorCheck: acctest.ErrorCheck(t, names.SFNServiceID),
ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories,
CheckDestroy: testAccCheckActivityDestroy(ctx),
Steps: []resource.TestStep{
{
Config: testAccActivityConfig_encryptionConfigurationCustomerManagedKMSKey(rName, string(awstypes.EncryptionTypeCustomerManagedKmsKey), reusePeriodSeconds),
Check: resource.ComposeTestCheckFunc(
testAccCheckActivityExists(ctx, resourceName),
resource.TestCheckResourceAttr(resourceName, "encryption_configuration.#", acctest.Ct1),
resource.TestCheckResourceAttr(resourceName, "encryption_configuration.0.type", string(awstypes.EncryptionTypeCustomerManagedKmsKey)),
resource.TestCheckResourceAttr(resourceName, "encryption_configuration.0.kms_data_key_reuse_period_seconds", strconv.Itoa(reusePeriodSeconds)),
resource.TestCheckResourceAttrSet(resourceName, "encryption_configuration.0.kms_key_id"),
resource.TestCheckResourceAttrPair(resourceName, "encryption_configuration.0.kms_key_id", kmsKeyResource, names.AttrARN),
),
},
{
ResourceName: resourceName,
ImportState: true,
ImportStateVerify: true,
},
},
})
}

func TestAccSFNActivity_encryptionConfigurationServiceOwnedKey(t *testing.T) {
ctx := acctest.Context(t)
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
resourceName := "aws_sfn_activity.test"

resource.Test(t, resource.TestCase{
PreCheck: func() { acctest.PreCheck(ctx, t) },
ErrorCheck: acctest.ErrorCheck(t, names.SFNServiceID),
ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories,
CheckDestroy: testAccCheckActivityDestroy(ctx),
Steps: []resource.TestStep{
{
Config: testAccActivityConfig_encryptionConfigurationServiceOwnedKey(rName, string(awstypes.EncryptionTypeAwsOwnedKey)),
Check: resource.ComposeTestCheckFunc(
testAccCheckActivityExists(ctx, resourceName),
resource.TestCheckResourceAttr(resourceName, "encryption_configuration.#", acctest.Ct1),
resource.TestCheckResourceAttr(resourceName, "encryption_configuration.0.type", string(awstypes.EncryptionTypeAwsOwnedKey)),
),
},
{
ResourceName: resourceName,
ImportState: true,
ImportStateVerify: true,
},
},
})
}

func testAccCheckActivityExists(ctx context.Context, n string) resource.TestCheckFunc {
return func(s *terraform.State) error {
rs, ok := s.RootModule().Resources[n]
Expand Down Expand Up @@ -163,6 +226,12 @@ func testAccCheckActivityDestroy(ctx context.Context) resource.TestCheckFunc {
}
}

func testAccActivityConfig_kmsBase() string {
return `
resource "aws_kms_key" "kms_key_for_sfn" {}
`
}

func testAccActivityConfig_basic(rName string) string {
return fmt.Sprintf(`
resource "aws_sfn_activity" "test" {
Expand Down Expand Up @@ -195,3 +264,27 @@ resource "aws_sfn_activity" "test" {
}
`, rName, tag1Key, tag1Value, tag2Key, tag2Value)
}

func testAccActivityConfig_encryptionConfigurationCustomerManagedKMSKey(rName string, rType string, reusePeriodSeconds int) string {
return acctest.ConfigCompose(testAccActivityConfig_kmsBase(), fmt.Sprintf(`
resource "aws_sfn_activity" "test" {
name = %[1]q
encryption_configuration {
kms_key_id = aws_kms_key.kms_key_for_sfn.arn
type = %[2]q
kms_data_key_reuse_period_seconds = %[3]d
}
}
`, rName, rType, reusePeriodSeconds))
}

func testAccActivityConfig_encryptionConfigurationServiceOwnedKey(rName string, rType string) string {
return acctest.ConfigCompose(testAccActivityConfig_kmsBase(), fmt.Sprintf(`
resource "aws_sfn_activity" "test" {
name = %[1]q
encryption_configuration {
type = %[2]q
}
}
`, rName, rType))
}
49 changes: 49 additions & 0 deletions internal/service/sfn/flex.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package sfn

import (
"github.com/aws/aws-sdk-go-v2/aws"
awstypes "github.com/aws/aws-sdk-go-v2/service/sfn/types"
"github.com/hashicorp/terraform-provider-aws/names"
)

func expandEncryptionConfiguration(tfMap map[string]interface{}) *awstypes.EncryptionConfiguration {
if tfMap == nil {
return nil
}

apiObject := &awstypes.EncryptionConfiguration{}

if v, ok := tfMap["kms_data_key_reuse_period_seconds"].(int); ok && v != 0 {
apiObject.KmsDataKeyReusePeriodSeconds = aws.Int32(int32(v))
}

if v, ok := tfMap[names.AttrKMSKeyID].(string); ok && v != "" {
apiObject.KmsKeyId = aws.String(v)
}

if v, ok := tfMap[names.AttrType].(string); ok && v != "" {
apiObject.Type = awstypes.EncryptionType(v)
}

return apiObject
}

func flattenEncryptionConfiguration(apiObject *awstypes.EncryptionConfiguration) map[string]interface{} {
if apiObject == nil {
return nil
}

tfMap := map[string]interface{}{
names.AttrKMSKeyID: aws.ToString(apiObject.KmsKeyId),
names.AttrType: apiObject.Type,
}

if v := apiObject.KmsDataKeyReusePeriodSeconds; v != nil {
tfMap["kms_data_key_reuse_period_seconds"] = aws.ToInt32(v)
}

return tfMap
}
69 changes: 57 additions & 12 deletions internal/service/sfn/state_machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,31 @@ func resourceStateMachine() *schema.Resource {
Type: schema.TypeString,
Computed: true,
},
names.AttrEncryptionConfiguration: {
Type: schema.TypeList,
Optional: true,
Computed: true,
MaxItems: 1,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"kms_data_key_reuse_period_seconds": {
Type: schema.TypeInt,
Optional: true,
ValidateFunc: validation.IntBetween(60, 900),
},
names.AttrKMSKeyID: {
Type: schema.TypeString,
Optional: true,
},
names.AttrType: {
Type: schema.TypeString,
Optional: true,
ValidateDiagFunc: enum.Validate[awstypes.EncryptionType](),
},
},
},
DiffSuppressFunc: verify.SuppressMissingOptionalConfigurationBlock,
},
names.AttrLoggingConfiguration: {
Type: schema.TypeList,
Optional: true,
Expand Down Expand Up @@ -117,23 +142,23 @@ func resourceStateMachine() *schema.Resource {
Default: false,
Optional: true,
},
"revision_id": {
Type: schema.TypeString,
Computed: true,
},
names.AttrRoleARN: {
Type: schema.TypeString,
Required: true,
ValidateFunc: verify.ValidARN,
},
"revision_id": {
"state_machine_version_arn": {
Type: schema.TypeString,
Computed: true,
},
names.AttrStatus: {
Type: schema.TypeString,
Computed: true,
},
"state_machine_version_arn": {
Type: schema.TypeString,
Computed: true,
},
names.AttrTags: tftags.TagsSchema(),
names.AttrTagsAll: tftags.TagsSchemaComputed(),
"tracing_configuration": {
Expand Down Expand Up @@ -182,6 +207,10 @@ func resourceStateMachineCreate(ctx context.Context, d *schema.ResourceData, met
Type: awstypes.StateMachineType(d.Get(names.AttrType).(string)),
}

if v, ok := d.GetOk(names.AttrEncryptionConfiguration); ok && len(v.([]interface{})) > 0 && v.([]interface{})[0] != nil {
input.EncryptionConfiguration = expandEncryptionConfiguration(v.([]interface{})[0].(map[string]interface{}))
}

if v, ok := d.GetOk(names.AttrLoggingConfiguration); ok && len(v.([]interface{})) > 0 && v.([]interface{})[0] != nil {
input.LoggingConfiguration = expandLoggingConfiguration(v.([]interface{})[0].(map[string]interface{}))
}
Expand All @@ -202,8 +231,7 @@ func resourceStateMachineCreate(ctx context.Context, d *schema.ResourceData, met
return sdkdiag.AppendErrorf(diags, "creating Step Functions State Machine (%s): %s", name, err)
}

arn := aws.ToString(outputRaw.(*sfn.CreateStateMachineOutput).StateMachineArn)
d.SetId(arn)
d.SetId(aws.ToString(outputRaw.(*sfn.CreateStateMachineOutput).StateMachineArn))

return append(diags, resourceStateMachineRead(ctx, d, meta)...)
}
Expand Down Expand Up @@ -232,6 +260,13 @@ func resourceStateMachineRead(ctx context.Context, d *schema.ResourceData, meta
}
d.Set("definition", output.Definition)
d.Set(names.AttrDescription, output.Description)
if output.EncryptionConfiguration != nil {
if err := d.Set(names.AttrEncryptionConfiguration, []interface{}{flattenEncryptionConfiguration(output.EncryptionConfiguration)}); err != nil {
return sdkdiag.AppendErrorf(diags, "setting encryption_configuration: %s", err)
}
} else {
d.Set(names.AttrEncryptionConfiguration, nil)
}
if output.LoggingConfiguration != nil {
if err := d.Set(names.AttrLoggingConfiguration, []interface{}{flattenLoggingConfiguration(output.LoggingConfiguration)}); err != nil {
return sdkdiag.AppendErrorf(diags, "setting logging_configuration: %s", err)
Expand All @@ -242,8 +277,8 @@ func resourceStateMachineRead(ctx context.Context, d *schema.ResourceData, meta
d.Set(names.AttrName, output.Name)
d.Set(names.AttrNamePrefix, create.NamePrefixFromName(aws.ToString(output.Name)))
d.Set("publish", d.Get("publish").(bool))
d.Set(names.AttrRoleARN, output.RoleArn)
d.Set("revision_id", output.RevisionId)
d.Set(names.AttrRoleARN, output.RoleArn)
d.Set(names.AttrStatus, output.Status)
if output.TracingConfiguration != nil {
if err := d.Set("tracing_configuration", []interface{}{flattenTracingConfiguration(output.TracingConfiguration)}); err != nil {
Expand Down Expand Up @@ -280,15 +315,18 @@ func resourceStateMachineUpdate(ctx context.Context, d *schema.ResourceData, met

if d.HasChangesExcept(names.AttrTags, names.AttrTagsAll) {
// "You must include at least one of definition or roleArn or you will receive a MissingRequiredParameter error"
publish := d.Get("publish").(bool)
input := &sfn.UpdateStateMachineInput{
Definition: aws.String(d.Get("definition").(string)),
Publish: publish,
RoleArn: aws.String(d.Get(names.AttrRoleARN).(string)),
StateMachineArn: aws.String(d.Id()),
Publish: d.Get("publish").(bool),
}

if v, ok := d.GetOk("publish"); ok && v == true {
input.VersionDescription = aws.String(d.Get("version_description").(string))
if d.HasChange(names.AttrEncryptionConfiguration) {
if v, ok := d.GetOk(names.AttrEncryptionConfiguration); ok && len(v.([]interface{})) > 0 && v.([]interface{})[0] != nil {
input.EncryptionConfiguration = expandEncryptionConfiguration(v.([]interface{})[0].(map[string]interface{}))
}
}

if d.HasChange(names.AttrLoggingConfiguration) {
Expand All @@ -303,6 +341,10 @@ func resourceStateMachineUpdate(ctx context.Context, d *schema.ResourceData, met
}
}

if publish {
input.VersionDescription = aws.String(d.Get("version_description").(string))
}

_, err := conn.UpdateStateMachine(ctx, input)

if err != nil {
Expand All @@ -322,7 +364,10 @@ func resourceStateMachineUpdate(ctx context.Context, d *schema.ResourceData, met
//d.HasChange("publish") && aws.Bool(output.Publish) != d.Get("publish").(bool) ||
d.HasChange("tracing_configuration.0.enabled") && output.TracingConfiguration != nil && output.TracingConfiguration.Enabled != d.Get("tracing_configuration.0.enabled").(bool) ||
d.HasChange("logging_configuration.0.include_execution_data") && output.LoggingConfiguration != nil && output.LoggingConfiguration.IncludeExecutionData != d.Get("logging_configuration.0.include_execution_data").(bool) ||
d.HasChange("logging_configuration.0.level") && output.LoggingConfiguration != nil && string(output.LoggingConfiguration.Level) != d.Get("logging_configuration.0.level").(string) {
d.HasChange("logging_configuration.0.level") && output.LoggingConfiguration != nil && string(output.LoggingConfiguration.Level) != d.Get("logging_configuration.0.level").(string) ||
d.HasChange("encryption_configuration.0.kms_key_id") && output.EncryptionConfiguration != nil && output.EncryptionConfiguration.KmsKeyId != nil && aws.ToString(output.EncryptionConfiguration.KmsKeyId) != d.Get("encryption_configuration.0.kms_key_id") ||
d.HasChange("encryption_configuration.0.encryption_type") && output.EncryptionConfiguration != nil && string(output.EncryptionConfiguration.Type) != d.Get("encryption_configuration.0.encryption_type").(string) ||
d.HasChange("encryption_configuration.0.kms_data_key_reuse_period_seconds") && output.EncryptionConfiguration != nil && output.EncryptionConfiguration.KmsDataKeyReusePeriodSeconds != nil && aws.ToInt32(output.EncryptionConfiguration.KmsDataKeyReusePeriodSeconds) != int32(d.Get("encryption_configuration.0.kms_data_key_reuse_period_seconds").(int)) {
return retry.RetryableError(fmt.Errorf("Step Functions State Machine (%s) eventual consistency", d.Id()))
}

Expand Down
Loading

0 comments on commit cbd253d

Please sign in to comment.