diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-endpoint-config.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-endpoint-config.ts index 8fe5749506c94..a7ff0134e1827 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-endpoint-config.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-endpoint-config.ts @@ -80,7 +80,8 @@ export class SageMakerCreateEndpointConfig extends sfn.TaskStateBase { KmsKeyId: this.props.kmsKey?.keyId, ProductionVariants: this.props.productionVariants.map((variant) => ({ InitialInstanceCount: variant.initialInstanceCount ? variant.initialInstanceCount : 1, - InstanceType: `ml.${variant.instanceType}`, + InstanceType: sfn.JsonPath.isEncodedJsonPath(variant.instanceType.toString()) + ? variant.instanceType.toString() : `ml.${variant.instanceType}`, ModelName: variant.modelName, VariantName: variant.variantName, AcceleratorType: variant.acceleratorType, diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-endpoint-config.test.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-endpoint-config.test.ts index f2c24889d1d0e..c59eb59eb5a83 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-endpoint-config.test.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-endpoint-config.test.ts @@ -69,7 +69,7 @@ test('create complex endpoint config', () => { { initialInstanceCount: 1, initialVariantWeight: 0.2, - instanceType: ec2.InstanceType.of(ec2.InstanceClass.M4, ec2.InstanceSize.XLARGE), + instanceType: new ec2.InstanceType(sfn.JsonPath.stringAt('$.Endpoint.InstanceType')), modelName: sfn.JsonPath.stringAt('$.Endpoint.Model'), variantName: 'awesome-variant-2', }], @@ -110,7 +110,7 @@ test('create complex endpoint config', () => { { 'InitialInstanceCount': 1, 'InitialVariantWeight': 0.2, - 'InstanceType': 'ml.m4.xlarge', + 'InstanceType.$': '$.Endpoint.InstanceType', 'ModelName.$': '$.Endpoint.Model', 'VariantName': 'awesome-variant-2', }],