add managed_instance_scaling to sagemaker endpoint config production_variants #35479

Merged
3 changes: 3 additions & 0 deletions .changelog/35479.txt
@@ -0,0 +1,3 @@
```release-note:enhancement
resource/aws_sagemaker_endpoint_configuration: Add `production_variants.managed_instance_scaling` and `shadow_production_variants.managed_instance_scaling` configuration blocks
```
110 changes: 110 additions & 0 deletions internal/service/sagemaker/endpoint_configuration.go
@@ -377,6 +377,34 @@ func resourceEndpointConfiguration() *schema.Resource {
},
},
},
"managed_instance_scaling": {
Type: schema.TypeList,
Optional: true,
MaxItems: 1,
ForceNew: true,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"max_instance_count": {
Type: schema.TypeInt,
Optional: true,
ForceNew: true,
ValidateFunc: validation.IntAtLeast(1),
},
"min_instance_count": {
Type: schema.TypeInt,
Optional: true,
ForceNew: true,
ValidateFunc: validation.IntAtLeast(1),
},
names.AttrStatus: {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
ValidateDiagFunc: enum.Validate[awstypes.ManagedInstanceScalingStatus](),
},
},
},
},
"variant_name": {
Type: schema.TypeString,
Optional: true,
@@ -521,6 +549,34 @@ func resourceEndpointConfiguration() *schema.Resource {
},
},
},
"managed_instance_scaling": {
Type: schema.TypeList,
Optional: true,
MaxItems: 1,
ForceNew: true,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"max_instance_count": {
Type: schema.TypeInt,
Optional: true,
ForceNew: true,
ValidateFunc: validation.IntAtLeast(1),
},
"min_instance_count": {
Type: schema.TypeInt,
Optional: true,
ForceNew: true,
ValidateFunc: validation.IntAtLeast(1),
},
names.AttrStatus: {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
ValidateDiagFunc: enum.Validate[awstypes.ManagedInstanceScalingStatus](),
},
},
},
},
"variant_name": {
Type: schema.TypeString,
Optional: true,
@@ -737,6 +793,10 @@ func expandProductionVariants(configured []interface{}) []awstypes.ProductionVariant {
l.EnableSSMAccess = aws.Bool(v)
}

if v, ok := data["managed_instance_scaling"].([]interface{}); ok && len(v) > 0 {
l.ManagedInstanceScaling = expandManagedInstanceScaling(v)
}

if v, ok := data["inference_ami_version"].(string); ok && v != "" {
l.InferenceAmiVersion = awstypes.ProductionVariantInferenceAmiVersion(v)
}
@@ -792,6 +852,10 @@ func flattenProductionVariants(list []awstypes.ProductionVariant) []map[string]interface{} {
l["enable_ssm_access"] = aws.ToBool(i.EnableSSMAccess)
}

if i.ManagedInstanceScaling != nil {
l["managed_instance_scaling"] = flattenManagedInstanceScaling(i.ManagedInstanceScaling)
}

result = append(result, l)
}
return result
@@ -1056,6 +1120,30 @@ func expandCoreDumpConfig(configured []interface{}) *awstypes.ProductionVariantCoreDumpConfig {
return c
}

func expandManagedInstanceScaling(configured []interface{}) *awstypes.ProductionVariantManagedInstanceScaling {
if len(configured) == 0 {
return nil
}

m := configured[0].(map[string]interface{})

c := &awstypes.ProductionVariantManagedInstanceScaling{}

if v, ok := m[names.AttrStatus].(string); ok {
c.Status = awstypes.ManagedInstanceScalingStatus(v)
}

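// min_instance_count and max_instance_count surface as 0 in the config map when
// unset, so only positive values are forwarded; unset fields stay nil and are
// omitted from the API request.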
if v, ok := m["min_instance_count"].(int); ok && v > 0 {
c.MinInstanceCount = aws.Int32(int32(v))
}

if v, ok := m["max_instance_count"].(int); ok && v > 0 {
c.MaxInstanceCount = aws.Int32(int32(v))
}

return c
}

func flattenEndpointConfigAsyncInferenceConfig(config *awstypes.AsyncInferenceConfig) []map[string]interface{} {
if config == nil {
return []map[string]interface{}{}
@@ -1185,3 +1273,25 @@ func flattenCoreDumpConfig(config *awstypes.ProductionVariantCoreDumpConfig) []map[string]interface{} {

return []map[string]interface{}{cfg}
}

func flattenManagedInstanceScaling(config *awstypes.ProductionVariantManagedInstanceScaling) []map[string]interface{} {
if config == nil {
return []map[string]interface{}{}
}

cfg := map[string]interface{}{}

if config.Status != "" {
cfg[names.AttrStatus] = config.Status
}

if config.MinInstanceCount != nil {
cfg["min_instance_count"] = aws.ToInt32(config.MinInstanceCount)
}

if config.MaxInstanceCount != nil {
cfg["max_instance_count"] = aws.ToInt32(config.MaxInstanceCount)
}

return []map[string]interface{}{cfg}
}
163 changes: 163 additions & 0 deletions internal/service/sagemaker/endpoint_configuration_test.go
@@ -761,6 +761,41 @@ func TestAccSageMakerEndpointConfiguration_upgradeToEnableSSMAccess(t *testing.T) {
})
}

func TestAccSageMakerEndpointConfiguration_productionVariantsManagedInstanceScaling(t *testing.T) {
ctx := acctest.Context(t)
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
resourceName := "aws_sagemaker_endpoint_configuration.test"

resource.ParallelTest(t, resource.TestCase{
PreCheck: func() { acctest.PreCheck(ctx, t) },
ErrorCheck: acctest.ErrorCheck(t, names.SageMakerServiceID),
ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories,
CheckDestroy: testAccCheckEndpointConfigurationDestroy(ctx),
Steps: []resource.TestStep{
{
Config: testAccEndpointConfigurationConfig_productionVariantsManagedInstanceScaling(rName),
Check: resource.ComposeAggregateTestCheckFunc(
testAccCheckEndpointConfigurationExists(ctx, resourceName),
resource.TestCheckResourceAttr(resourceName, names.AttrName, rName),
resource.TestCheckResourceAttr(resourceName, "production_variants.#", acctest.Ct1),
resource.TestCheckResourceAttr(resourceName, "production_variants.0.variant_name", "variant-1"),
resource.TestCheckResourceAttr(resourceName, "production_variants.0.model_name", rName),
resource.TestCheckResourceAttr(resourceName, "production_variants.0.initial_instance_count", acctest.Ct1),
resource.TestCheckResourceAttr(resourceName, "production_variants.0.instance_type", "ml.g5.4xlarge"),
resource.TestCheckResourceAttr(resourceName, "production_variants.0.managed_instance_scaling.0.status", "ENABLED"),
resource.TestCheckResourceAttr(resourceName, "production_variants.0.managed_instance_scaling.0.min_instance_count", acctest.Ct1),
resource.TestCheckResourceAttr(resourceName, "production_variants.0.managed_instance_scaling.0.max_instance_count", acctest.Ct2),
),
},
{
ResourceName: resourceName,
ImportState: true,
ImportStateVerify: true,
},
},
})
}

func testAccCheckEndpointConfigurationDestroy(ctx context.Context) resource.TestCheckFunc {
return func(s *terraform.State) error {
conn := acctest.Provider.Meta().(*conns.AWSClient).SageMakerClient(ctx)
@@ -1395,3 +1430,131 @@ resource "aws_sagemaker_endpoint_configuration" "test" {
}
`, rName))
}

func testAccEndpointConfigurationConfig_productionVariantsManagedInstanceScaling(rName string) string {
return acctest.ConfigCompose(fmt.Sprintf(`
data "aws_region" "current" {}
data "aws_partition" "current" {}
data "aws_sagemaker_prebuilt_ecr_image" "managed_instance_scaling_test" {
repository_name = "djl-inference"
image_tag = "0.27.0-deepspeed0.12.6-cu121"
}

data "aws_iam_policy_document" "managed_instance_scaling_test_policy" {
statement {
effect = "Allow"

actions = [
"cloudwatch:PutMetricData",
"logs:CreateLogStream",
"logs:PutLogEvents",
"logs:CreateLogGroup",
"logs:DescribeLogStreams",
"ecr:GetAuthorizationToken",
"ecr:BatchCheckLayerAvailability",
"ecr:GetDownloadUrlForLayer",
"ecr:BatchGetImage",
]

resources = [
"*",
]
}

statement {
effect = "Allow"

actions = [
"s3:GetObject",
"s3:ListBucket",
]

resources = [
"${aws_s3_bucket.managed_instance_scaling_test.arn}",
"${aws_s3_bucket.managed_instance_scaling_test.arn}/*",
]
}
}

resource "aws_iam_policy" "managed_instance_scaling_test" {
name = %[1]q
description = "Allow SageMaker to create model"
policy = data.aws_iam_policy_document.managed_instance_scaling_test_policy.json
}

resource "aws_iam_role" "managed_instance_scaling_test" {
name = %[1]q
path = "/"
assume_role_policy = data.aws_iam_policy_document.assume_role.json
}

data "aws_iam_policy_document" "assume_role" {
statement {
actions = ["sts:AssumeRole"]

principals {
type = "Service"
identifiers = ["sagemaker.amazonaws.com"]
}
}
}

resource "aws_iam_role_policy_attachment" "managed_instance_scaling_test" {
role = aws_iam_role.managed_instance_scaling_test.name
policy_arn = aws_iam_policy.managed_instance_scaling_test.arn
}

resource "aws_s3_bucket" "managed_instance_scaling_test" {
bucket = %[1]q
force_destroy = true
}

resource "aws_s3_object" "managed_instance_scaling_test" {
bucket = aws_s3_bucket.managed_instance_scaling_test.bucket
key = "model/inference.py"
content = "some-data"
}

resource "aws_sagemaker_model" "managed_instance_scaling_test" {
name = %[1]q
execution_role_arn = aws_iam_role.managed_instance_scaling_test.arn
primary_container {
image = data.aws_sagemaker_prebuilt_ecr_image.managed_instance_scaling_test.registry_path
model_data_source {
s3_data_source {
s3_data_type = "S3Prefix"
s3_uri = "s3://${aws_s3_object.managed_instance_scaling_test.bucket}/model/"
compression_type = "None"
}
}
}
depends_on = [
aws_iam_role_policy_attachment.managed_instance_scaling_test
]
}

resource "aws_sagemaker_endpoint_configuration" "test" {
name = %[1]q

production_variants {
variant_name = "variant-1"
model_name = aws_sagemaker_model.managed_instance_scaling_test.name
initial_instance_count = 1
instance_type = "ml.g5.4xlarge"

managed_instance_scaling {
status = "ENABLED"
min_instance_count = 1
max_instance_count = 2
}

routing_config {
routing_strategy = "LEAST_OUTSTANDING_REQUESTS"
}

model_data_download_timeout_in_seconds = 60
container_startup_health_check_timeout_in_seconds = 60
}
}
`, rName))
}
@@ -58,6 +58,7 @@ This resource supports the following arguments:
* `model_name` - (Required) The name of the model to use.
* `routing_config` - (Optional) Sets how the endpoint routes incoming traffic. See [routing_config](#routing_config) below.
* `serverless_config` - (Optional) Specifies configuration for how an endpoint performs serverless inference. See [serverless_config](#serverless_config) below.
* `managed_instance_scaling` - (Optional) Settings that control the range in the number of instances that the endpoint provisions as it scales up or down to accommodate traffic. See [managed_instance_scaling](#managed_instance_scaling) below.
* `variant_name` - (Optional) The name of the variant. If omitted, Terraform will assign a random, unique name.
* `volume_size_in_gb` - (Optional) The size, in GB, of the ML storage volume attached to each inference instance associated with the production variant. Valid values are between `1` and `512`.

@@ -76,6 +77,12 @@ This resource supports the following arguments:
* `memory_size_in_mb` - (Required) The memory size of your serverless endpoint. Valid values are in 1 GB increments: `1024` MB, `2048` MB, `3072` MB, `4096` MB, `5120` MB, or `6144` MB.
* `provisioned_concurrency` - (Optional) The amount of provisioned concurrency to allocate for the serverless endpoint. Should be less than or equal to `max_concurrency`. Valid values are between `1` and `200`.

#### managed_instance_scaling

* `status` - (Optional) Indicates whether managed instance scaling is enabled. Valid values are `ENABLED` and `DISABLED`.
* `min_instance_count` - (Optional) The minimum number of instances that the endpoint must retain when it scales down to accommodate a decrease in traffic.
* `max_instance_count` - (Optional) The maximum number of instances that the endpoint can provision when it scales up to accommodate an increase in traffic.
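
For reference, a minimal `production_variants` block using managed instance scaling, mirroring the acceptance test in this PR (resource and model names are illustrative; the same block is also accepted under `shadow_production_variants`):

```hcl
resource "aws_sagemaker_endpoint_configuration" "example" {
  name = "example"

  production_variants {
    variant_name           = "variant-1"
    model_name             = aws_sagemaker_model.example.name # assumes an existing model resource
    initial_instance_count = 1
    instance_type          = "ml.g5.4xlarge"

    managed_instance_scaling {
      status             = "ENABLED"
      min_instance_count = 1
      max_instance_count = 2
    }
  }
}
```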

### data_capture_config

* `initial_sampling_percentage` - (Required) Portion of data to capture. Should be between 0 and 100.