diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/agent.py index 93ce7638d0..5af832f7b5 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/agent.py @@ -83,7 +83,7 @@ async def create( async def get(self, resource_meta: SageMakerEndpointMetadata, **kwargs) -> Resource: try: - endpoint_status, idempotence_token = await self._call( + endpoint_status, _ = await self._call( method="describe_endpoint", config={"EndpointName": resource_meta.config.get("EndpointName")}, inputs=resource_meta.inputs, @@ -97,7 +97,7 @@ async def get(self, resource_meta: SageMakerEndpointMetadata, **kwargs) -> Resou if error_code == "ValidationException" and "Could not find endpoint" in error_message: raise Exception( "This might be due to resource limits being exceeded, preventing the creation of a new endpoint. Please check your resource usage and limits." - ) from e + ) raise e current_state = endpoint_status.get("EndpointStatus") @@ -109,10 +109,7 @@ async def get(self, resource_meta: SageMakerEndpointMetadata, **kwargs) -> Resou res = None if current_state == "InService": - res = { - "result": {"EndpointArn": endpoint_status.get("EndpointArn")}, - "idempotence_token": idempotence_token, - } + res = {"result": {"EndpointArn": endpoint_status.get("EndpointArn")}} return Resource(phase=flyte_phase, outputs=res, message=message) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py index 7b935e9101..5e34557e40 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py @@ -59,6 +59,7 @@ async def do( boto3_object = Boto3AgentMixin(service=service, region=region) + result = None try: result, idempotence_token = await boto3_object._call( method=method, @@ -77,10 +78,16 @@ async def do( error_message, ).group(0) if arn: + arn_result = None + if method == "create_model": + arn_result = {"ModelArn": arn} + elif method == "create_endpoint_config": + arn_result = {"EndpointConfigArn": arn} + return Resource( phase=TaskExecution.SUCCEEDED, outputs={ - "result": {"result": f"Entity already exists: {arn}"}, + "result": arn_result if arn_result else {"result": f"Entity already exists {arn}."}, "idempotence_token": e.idempotence_token, }, ) @@ -100,6 +107,12 @@ async def do( outputs = {"result": {"result": None}} if result: + truncated_result = None + if method == "create_model": + truncated_result = {"ModelArn": result.get("ModelArn")} + elif method == "create_endpoint_config": + truncated_result = {"EndpointConfigArn": result.get("EndpointConfigArn")} + ctx = FlyteContextManager.current_context() builder = ctx.with_file_access( FileAccessProvider( @@ -113,7 +126,7 @@ async def do( literals={ "result": TypeEngine.to_literal( new_ctx, - result, + truncated_result if truncated_result else result, Annotated[dict, kwtypes(allow_pickle=True)], TypeEngine.to_literal_type(dict), ), diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py index cf3cc0c14b..05dac9de59 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py @@ -81,12 +81,12 @@ def update_dict_fn( except Exception: raise ValueError(f"Could not find the key {key} in {update_dict_copy}.") - if len(matches) > 1: - # Replace the placeholder in the original_dict - original_dict = original_dict.replace(f"{{{match}}}", update_dict_copy) - else: + if f"{{{match}}}" == original_dict: # If there's only one match, it needn't always be a string, so not replacing the original dict. return update_dict_copy + else: + # Replace the placeholder in the original_dict + original_dict = original_dict.replace(f"{{{match}}}", update_dict_copy) elif match == "idempotence_token" and idempotence_token: temp_dict = original_dict.replace(f"{{{match}}}", idempotence_token) if len(temp_dict) > 63: diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/task.py index 8714915776..afae35d3e0 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/task.py @@ -95,7 +95,7 @@ def __init__( super().__init__( name=name, task_type=self._TASK_TYPE, - interface=Interface(inputs=inputs, outputs=kwtypes(result=dict, idempotence_token=str)), + interface=Interface(inputs=inputs, outputs=kwtypes(result=dict)), **kwargs, ) self._config = config diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py index 13b89e8ec4..be76a0a634 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py @@ -38,6 +38,13 @@ def create_deployment_task( ) +def append_token(config, key, token, name): + if key in config: + config[key] += f"-{{{token}}}" + else: + config[key] = f"{name}-{{{token}}}" + + def create_sagemaker_deployment( name: str, model_config: Dict[str, Any], @@ -49,6 +56,7 @@ def create_sagemaker_deployment( endpoint_input_types: Optional[Dict[str, Type]] = None, region: Optional[str] = None, region_at_runtime: bool = False, + idempotence_token: bool = True, ) -> Workflow: """ Creates SageMaker model, endpoint config and endpoint. @@ -62,6 +70,7 @@ def create_sagemaker_deployment( :param endpoint_input_types: Mapping of SageMaker endpoint inputs to their types. :param region: The region for SageMaker API calls. :param region_at_runtime: Set this to True if you want to provide the region at runtime. + :param idempotence_token: Set this to False if you don't want the agent to automatically append a token/hash to the deployment names. """ if not any((region, region_at_runtime)): raise ValueError("Region parameter is required.") @@ -71,6 +80,21 @@ def create_sagemaker_deployment( if region_at_runtime: wf.add_workflow_input("region", str) + if idempotence_token: + append_token(model_config, "ModelName", "idempotence_token", name) + append_token(endpoint_config_config, "EndpointConfigName", "idempotence_token", name) + + if "ProductionVariants" in endpoint_config_config and endpoint_config_config["ProductionVariants"]: + append_token( + endpoint_config_config["ProductionVariants"][0], + "ModelName", + "inputs.idempotence_token", + name, + ) + + append_token(endpoint_config, "EndpointName", "idempotence_token", name) + append_token(endpoint_config, "EndpointConfigName", "inputs.idempotence_token", name) + inputs = { SageMakerModelTask: { "input_types": model_input_types, diff --git a/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py index fcdffe83fa..baf26fdffa 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py @@ -23,55 +23,42 @@ "mock_return_value", [ ( - { - "ResponseMetadata": { - "RequestId": "66f80391-348a-4ee0-9158-508914d16db2", - "HTTPStatusCode": 200.0, - "RetryAttempts": 0.0, - "HTTPHeaders": { - "content-type": "application/x-amz-json-1.1", - "date": "Wed, 31 Jan 2024 16:43:52 GMT", - "x-amzn-requestid": "66f80391-348a-4ee0-9158-508914d16db2", - "content-length": "114", - }, + ( + { + "EndpointConfigArn": "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config", }, - "EndpointConfigArn": "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config", - }, - idempotence_token, + idempotence_token, + ), + "create_endpoint_config", ), ( - { - "ResponseMetadata": { - "RequestId": "66f80391-348a-4ee0-9158-508914d16db2", - "HTTPStatusCode": 200.0, - "RetryAttempts": 0.0, - "HTTPHeaders": { - "content-type": "application/x-amz-json-1.1", - "date": "Wed, 31 Jan 2024 16:43:52 GMT", - "x-amzn-requestid": "66f80391-348a-4ee0-9158-508914d16db2", - "content-length": "114", - }, + ( + { + "pickle_check": datetime(2024, 5, 5), + "Location": "http://examplebucket.s3.amazonaws.com/", }, - "pickle_check": datetime(2024, 5, 5), - "EndpointConfigArn": "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config", - }, - idempotence_token, + idempotence_token, + ), + "create_bucket", ), - (None, idempotence_token), + ((None, idempotence_token), "create_endpoint_config"), ( - CustomException( - message="An error occurred", - idempotence_token=idempotence_token, - original_exception=ClientError( - error_response={ - "Error": { - "Code": "ValidationException", - "Message": "Cannot create already existing endpoint 'arn:aws:sagemaker:us-east-2:123456789:endpoint/stable-diffusion-endpoint-non-finetuned-06716dbe4b2c68e7'", - } - }, - operation_name="DescribeEndpoint", - ), - ) + ( + CustomException( + message="An error occurred", + idempotence_token=idempotence_token, + original_exception=ClientError( + error_response={ + "Error": { + "Code": "ValidationException", + "Message": "Cannot create already existing endpoint 'arn:aws:sagemaker:us-east-2:123456789:endpoint/stable-diffusion-endpoint-non-finetuned-06716dbe4b2c68e7'", + } + }, + operation_name="DescribeEndpoint", + ), + ) + ), + "create_endpoint_config", ), ], ) @@ -79,7 +66,7 @@ "flytekitplugins.awssagemaker_inference.boto3_agent.Boto3AgentMixin._call", ) async def test_agent(mock_boto_call, mock_return_value): - mock_boto_call.return_value = mock_return_value + mock_boto_call.return_value = mock_return_value[0] agent = AgentRegistry.get_agent("boto") task_id = Identifier( @@ -106,7 +93,7 @@ async def test_agent(mock_boto_call, mock_return_value): }, }, "region": "us-east-2", - "method": "create_endpoint_config", + "method": mock_return_value[1], "images": None, } task_metadata = TaskMetadata( @@ -149,8 +136,8 @@ async def test_agent(mock_boto_call, mock_return_value): ctx = FlyteContext.current_context() output_prefix = ctx.file_access.get_random_remote_directory() - if isinstance(mock_return_value, Exception): - mock_boto_call.side_effect = mock_return_value + if isinstance(mock_return_value[0], Exception): + mock_boto_call.side_effect = mock_return_value[0] resource = await agent.do( task_template=task_template, @@ -158,7 +145,7 @@ async def test_agent(mock_boto_call, mock_return_value): output_prefix=output_prefix, ) assert resource.outputs["result"] == { - "result": f"Entity already exists: arn:aws:sagemaker:us-east-2:123456789:endpoint/stable-diffusion-endpoint-non-finetuned-06716dbe4b2c68e7" + "EndpointConfigArn": "arn:aws:sagemaker:us-east-2:123456789:endpoint/stable-diffusion-endpoint-non-finetuned-06716dbe4b2c68e7" } assert resource.outputs["idempotence_token"] == idempotence_token return @@ -169,9 +156,9 @@ async def test_agent(mock_boto_call, mock_return_value): assert resource.phase == TaskExecution.SUCCEEDED - if mock_return_value[0]: + if mock_return_value[0][0]: outputs = literal_map_string_repr(resource.outputs) - if "pickle_check" in mock_return_value[0]: + if "pickle_check" in mock_return_value[0][0]: assert "pickle_file" in outputs["result"] else: assert ( @@ -179,5 +166,5 @@ async def test_agent(mock_boto_call, mock_return_value): == "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config" ) assert outputs["idempotence_token"] == "74443947857331f7" - elif mock_return_value[0] is None: + elif mock_return_value[0][0] is None: assert resource.outputs["result"] == {"result": None} diff --git a/plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py b/plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py index b3c8cba2e6..076100f60c 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py @@ -171,7 +171,6 @@ async def test_agent(mock_boto_call, mock_return_value): resource.outputs["result"]["EndpointArn"] == "arn:aws:sagemaker:us-east-2:1234567890:endpoint/sagemaker-xgboost-endpoint" ) - assert resource.outputs["idempotence_token"] == idempotence_token # DELETE delete_response = await agent.delete(metadata) diff --git a/plugins/flytekit-aws-sagemaker/tests/test_inference_task.py b/plugins/flytekit-aws-sagemaker/tests/test_inference_task.py index f74e0cc4b6..5e72ca79ed 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_inference_task.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_inference_task.py @@ -73,7 +73,7 @@ kwtypes(endpoint_name=str, endpoint_config_name=str), None, 2, - 2, + 1, "us-east-2", SageMakerEndpointTask, ),