Skip to content

Commit

Permalink
Merge pull request #35479 from deepakbshetty/f-aws_sagemaker_endpoint…
Browse files Browse the repository at this point in the history
…_config-managed_instance_scaling

add managed_instance_scaling to sagemaker endpoint config production_variants
  • Loading branch information
ewbankkit authored Sep 9, 2024
2 parents ccf99b0 + 73ab815 commit 2eae23d
Show file tree
Hide file tree
Showing 4 changed files with 283 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .changelog/35479.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:enhancement
resource/aws_sagemaker_endpoint_configuration: Add `production_variants.managed_instance_scaling` and `shadow_production_variants.managed_instance_scaling` configuration blocks
```
110 changes: 110 additions & 0 deletions internal/service/sagemaker/endpoint_configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,34 @@ func resourceEndpointConfiguration() *schema.Resource {
},
},
},
"managed_instance_scaling": {
Type: schema.TypeList,
Optional: true,
MaxItems: 1,
ForceNew: true,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"max_instance_count": {
Type: schema.TypeInt,
Optional: true,
ForceNew: true,
ValidateFunc: validation.IntAtLeast(1),
},
"min_instance_count": {
Type: schema.TypeInt,
Optional: true,
ForceNew: true,
ValidateFunc: validation.IntAtLeast(1),
},
names.AttrStatus: {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
ValidateDiagFunc: enum.Validate[awstypes.ManagedInstanceScalingStatus](),
},
},
},
},
"variant_name": {
Type: schema.TypeString,
Optional: true,
Expand Down Expand Up @@ -521,6 +549,34 @@ func resourceEndpointConfiguration() *schema.Resource {
},
},
},
"managed_instance_scaling": {
Type: schema.TypeList,
Optional: true,
MaxItems: 1,
ForceNew: true,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"max_instance_count": {
Type: schema.TypeInt,
Optional: true,
ForceNew: true,
ValidateFunc: validation.IntAtLeast(1),
},
"min_instance_count": {
Type: schema.TypeInt,
Optional: true,
ForceNew: true,
ValidateFunc: validation.IntAtLeast(1),
},
names.AttrStatus: {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
ValidateDiagFunc: enum.Validate[awstypes.ManagedInstanceScalingStatus](),
},
},
},
},
"variant_name": {
Type: schema.TypeString,
Optional: true,
Expand Down Expand Up @@ -737,6 +793,10 @@ func expandProductionVariants(configured []interface{}) []awstypes.ProductionVar
l.EnableSSMAccess = aws.Bool(v)
}

if v, ok := data["managed_instance_scaling"].([]interface{}); ok && len(v) > 0 {
l.ManagedInstanceScaling = expandManagedInstanceScaling(v)
}

if v, ok := data["inference_ami_version"].(string); ok && v != "" {
l.InferenceAmiVersion = awstypes.ProductionVariantInferenceAmiVersion(v)
}
Expand Down Expand Up @@ -792,6 +852,10 @@ func flattenProductionVariants(list []awstypes.ProductionVariant) []map[string]i
l["enable_ssm_access"] = aws.ToBool(i.EnableSSMAccess)
}

if i.ManagedInstanceScaling != nil {
l["managed_instance_scaling"] = flattenManagedInstanceScaling(i.ManagedInstanceScaling)
}

result = append(result, l)
}
return result
Expand Down Expand Up @@ -1056,6 +1120,30 @@ func expandCoreDumpConfig(configured []interface{}) *awstypes.ProductionVariantC
return c
}

func expandManagedInstanceScaling(configured []interface{}) *awstypes.ProductionVariantManagedInstanceScaling {
if len(configured) == 0 {
return nil
}

m := configured[0].(map[string]interface{})

c := &awstypes.ProductionVariantManagedInstanceScaling{}

if v, ok := m[names.AttrStatus].(string); ok {
c.Status = awstypes.ManagedInstanceScalingStatus(v)
}

if v, ok := m["min_instance_count"].(int); ok && v > 0 {
c.MinInstanceCount = aws.Int32(int32(v))
}

if v, ok := m["max_instance_count"].(int); ok && v > 0 {
c.MaxInstanceCount = aws.Int32(int32(v))
}

return c
}

func flattenEndpointConfigAsyncInferenceConfig(config *awstypes.AsyncInferenceConfig) []map[string]interface{} {
if config == nil {
return []map[string]interface{}{}
Expand Down Expand Up @@ -1185,3 +1273,25 @@ func flattenCoreDumpConfig(config *awstypes.ProductionVariantCoreDumpConfig) []m

return []map[string]interface{}{cfg}
}

func flattenManagedInstanceScaling(config *awstypes.ProductionVariantManagedInstanceScaling) []map[string]interface{} {
if config == nil {
return []map[string]interface{}{}
}

cfg := map[string]interface{}{}

if config.Status != "" {
cfg[names.AttrStatus] = config.Status
}

if config.MinInstanceCount != nil {
cfg["min_instance_count"] = aws.ToInt32(config.MinInstanceCount)
}

if config.MaxInstanceCount != nil {
cfg["max_instance_count"] = aws.ToInt32(config.MaxInstanceCount)
}

return []map[string]interface{}{cfg}
}
163 changes: 163 additions & 0 deletions internal/service/sagemaker/endpoint_configuration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,41 @@ func TestAccSageMakerEndpointConfiguration_upgradeToEnableSSMAccess(t *testing.T
})
}

func TestAccSageMakerEndpointConfiguration_productionVariantsManagedInstanceScaling(t *testing.T) {
ctx := acctest.Context(t)
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
resourceName := "aws_sagemaker_endpoint_configuration.test"

resource.ParallelTest(t, resource.TestCase{
PreCheck: func() { acctest.PreCheck(ctx, t) },
ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID),
ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories,
CheckDestroy: testAccCheckEndpointConfigurationDestroy(ctx),
Steps: []resource.TestStep{
{
Config: testAccEndpointConfigurationConfig_productionVariantsManagedInstanceScaling(rName),
Check: resource.ComposeAggregateTestCheckFunc(
testAccCheckEndpointConfigurationExists(ctx, resourceName),
resource.TestCheckResourceAttr(resourceName, names.AttrName, rName),
resource.TestCheckResourceAttr(resourceName, "production_variants.#", acctest.Ct1),
resource.TestCheckResourceAttr(resourceName, "production_variants.0.variant_name", "variant-1"),
resource.TestCheckResourceAttr(resourceName, "production_variants.0.model_name", rName),
resource.TestCheckResourceAttr(resourceName, "production_variants.0.initial_instance_count", acctest.Ct1),
resource.TestCheckResourceAttr(resourceName, "production_variants.0.instance_type", "ml.g5.4xlarge"),
resource.TestCheckResourceAttr(resourceName, "production_variants.0.managed_instance_scaling.0.status", "ENABLED"),
resource.TestCheckResourceAttr(resourceName, "production_variants.0.managed_instance_scaling.0.min_instance_count", acctest.Ct1),
resource.TestCheckResourceAttr(resourceName, "production_variants.0.managed_instance_scaling.0.max_instance_count", acctest.Ct2),
),
},
{
ResourceName: resourceName,
ImportState: true,
ImportStateVerify: true,
},
},
})
}

func testAccCheckEndpointConfigurationDestroy(ctx context.Context) resource.TestCheckFunc {
return func(s *terraform.State) error {
conn := acctest.Provider.Meta().(*conns.AWSClient).SageMakerClient(ctx)
Expand Down Expand Up @@ -1395,3 +1430,131 @@ resource "aws_sagemaker_endpoint_configuration" "test" {
}
`, rName))
}

func testAccEndpointConfigurationConfig_productionVariantsManagedInstanceScaling(rName string) string {
return acctest.ConfigCompose(fmt.Sprintf(`
data "aws_region" "current" {}
data "aws_partition" "current" {}
data "aws_sagemaker_prebuilt_ecr_image" "managed_instance_scaling_test" {
repository_name = "djl-inference"
image_tag = "0.27.0-deepspeed0.12.6-cu121"
}
data "aws_iam_policy_document" "managed_instance_scaling_test_policy" {
statement {
effect = "Allow"
actions = [
"cloudwatch:PutMetricData",
"logs:CreateLogStream",
"logs:PutLogEvents",
"logs:CreateLogGroup",
"logs:DescribeLogStreams",
"ecr:GetAuthorizationToken",
"ecr:BatchCheckLayerAvailability",
"ecr:GetDownloadUrlForLayer",
"ecr:BatchGetImage",
]
resources = [
"*",
]
}
statement {
effect = "Allow"
actions = [
"s3:GetObject",
"s3:ListBucket",
]
resources = [
"${aws_s3_bucket.managed_instance_scaling_test.arn}",
"${aws_s3_bucket.managed_instance_scaling_test.arn}/*",
]
}
}
resource "aws_iam_policy" "managed_instance_scaling_test" {
name = %[1]q
description = "Allow SageMaker to create model"
policy = data.aws_iam_policy_document.managed_instance_scaling_test_policy.json
}
resource "aws_iam_role" "managed_instance_scaling_test" {
name = %[1]q
path = "/"
assume_role_policy = data.aws_iam_policy_document.assume_role.json
}
data "aws_iam_policy_document" "assume_role" {
statement {
actions = ["sts:AssumeRole"]
principals {
type = "Service"
identifiers = ["sagemaker.amazonaws.com"]
}
}
}
resource "aws_iam_role_policy_attachment" "managed_instance_scaling_test" {
role = aws_iam_role.managed_instance_scaling_test.name
policy_arn = aws_iam_policy.managed_instance_scaling_test.arn
}
resource "aws_s3_bucket" "managed_instance_scaling_test" {
bucket = %[1]q
force_destroy = true
}
resource "aws_s3_object" "managed_instance_scaling_test" {
bucket = aws_s3_bucket.managed_instance_scaling_test.bucket
key = "model/inference.py"
content = "some-data"
}
resource "aws_sagemaker_model" "managed_instance_scaling_test" {
name = %[1]q
execution_role_arn = aws_iam_role.managed_instance_scaling_test.arn
primary_container {
image = data.aws_sagemaker_prebuilt_ecr_image.managed_instance_scaling_test.registry_path
model_data_source {
s3_data_source {
s3_data_type = "S3Prefix"
s3_uri = "s3://${aws_s3_object.managed_instance_scaling_test.bucket}/model/"
compression_type = "None"
}
}
}
depends_on = [
aws_iam_role_policy_attachment.managed_instance_scaling_test
]
}
resource "aws_sagemaker_endpoint_configuration" "test" {
name = %[1]q
production_variants {
variant_name = "variant-1"
model_name = aws_sagemaker_model.managed_instance_scaling_test.name
initial_instance_count = 1
instance_type = "ml.g5.4xlarge"
managed_instance_scaling {
status = "ENABLED"
min_instance_count = 1
max_instance_count = 2
}
routing_config {
routing_strategy = "LEAST_OUTSTANDING_REQUESTS"
}
model_data_download_timeout_in_seconds = 60
container_startup_health_check_timeout_in_seconds = 60
}
}
`, rName))
}
7 changes: 7 additions & 0 deletions website/docs/r/sagemaker_endpoint_configuration.html.markdown
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ This resource supports the following arguments:
* `model_name` - (Required) The name of the model to use.
* `routing_config` - (Optional) Sets how the endpoint routes incoming traffic. See [routing_config](#routing_config) below.
* `serverless_config` - (Optional) Specifies configuration for how an endpoint performs asynchronous inference.
* `managed_instance_scaling` - (Optional) Settings that control the range in the number of instances that the endpoint provisions as it scales up or down to accommodate traffic.
* `variant_name` - (Optional) The name of the variant. If omitted, Terraform will assign a random, unique name.
* `volume_size_in_gb` - (Optional) The size, in GB, of the ML storage volume attached to individual inference instance associated with the production variant. Valid values between `1` and `512`.

Expand All @@ -76,6 +77,12 @@ This resource supports the following arguments:
* `memory_size_in_mb` - (Required) The memory size of your serverless endpoint. Valid values are in 1 GB increments: `1024` MB, `2048` MB, `3072` MB, `4096` MB, `5120` MB, or `6144` MB.
* `provisioned_concurrency` - The amount of provisioned concurrency to allocate for the serverless endpoint. Should be less than or equal to `max_concurrency`. Valid values are between `1` and `200`.

#### managed_instance_scaling

* `status` - (Optional) Indicates whether managed instance scaling is enabled. Valid values are `ENABLED` and `DISABLED`.
* `min_instance_count` - (Optional) The minimum number of instances that the endpoint must retain when it scales down to accommodate a decrease in traffic.
* `max_instance_count` - (Optional) The maximum number of instances that the endpoint can provision when it scales up to accommodate an increase in traffic.

### data_capture_config

* `initial_sampling_percentage` - (Required) Portion of data to capture. Should be between 0 and 100.
Expand Down

0 comments on commit 2eae23d

Please sign in to comment.