Skip to content

Commit

Permalink
truncate sagemaker agent outputs and automate idempotence token handl…
Browse files Browse the repository at this point in the history
…ing (flyteorg#2588)

* truncate sagemaker agent outputs

Signed-off-by: Samhita Alla <[email protected]>

* fix tests and update agent output

Signed-off-by: Samhita Alla <[email protected]>

* lint

Signed-off-by: Samhita Alla <[email protected]>

* fix test

Signed-off-by: Samhita Alla <[email protected]>

* add idempotence token to workflow

Signed-off-by: Samhita Alla <[email protected]>

* fix type

Signed-off-by: Samhita Alla <[email protected]>

* fix mixin

Signed-off-by: Samhita Alla <[email protected]>

* modify output handler

Signed-off-by: Samhita Alla <[email protected]>

---------

Signed-off-by: Samhita Alla <[email protected]>
Signed-off-by: mao3267 <[email protected]>
  • Loading branch information
samhita-alla authored and mao3267 committed Jul 29, 2024
1 parent 4bb172f commit bba4e97
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async def create(

async def get(self, resource_meta: SageMakerEndpointMetadata, **kwargs) -> Resource:
try:
endpoint_status, idempotence_token = await self._call(
endpoint_status, _ = await self._call(
method="describe_endpoint",
config={"EndpointName": resource_meta.config.get("EndpointName")},
inputs=resource_meta.inputs,
Expand All @@ -97,7 +97,7 @@ async def get(self, resource_meta: SageMakerEndpointMetadata, **kwargs) -> Resou
if error_code == "ValidationException" and "Could not find endpoint" in error_message:
raise Exception(
"This might be due to resource limits being exceeded, preventing the creation of a new endpoint. Please check your resource usage and limits."
) from e
)
raise e

current_state = endpoint_status.get("EndpointStatus")
Expand All @@ -109,10 +109,7 @@ async def get(self, resource_meta: SageMakerEndpointMetadata, **kwargs) -> Resou

res = None
if current_state == "InService":
res = {
"result": {"EndpointArn": endpoint_status.get("EndpointArn")},
"idempotence_token": idempotence_token,
}
res = {"result": {"EndpointArn": endpoint_status.get("EndpointArn")}}

return Resource(phase=flyte_phase, outputs=res, message=message)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ async def do(

boto3_object = Boto3AgentMixin(service=service, region=region)

result = None
try:
result, idempotence_token = await boto3_object._call(
method=method,
Expand All @@ -77,10 +78,16 @@ async def do(
error_message,
).group(0)
if arn:
arn_result = None
if method == "create_model":
arn_result = {"ModelArn": arn}
elif method == "create_endpoint_config":
arn_result = {"EndpointConfigArn": arn}

return Resource(
phase=TaskExecution.SUCCEEDED,
outputs={
"result": {"result": f"Entity already exists: {arn}"},
"result": arn_result if arn_result else {"result": f"Entity already exists {arn}."},
"idempotence_token": e.idempotence_token,
},
)
Expand All @@ -100,6 +107,12 @@ async def do(

outputs = {"result": {"result": None}}
if result:
truncated_result = None
if method == "create_model":
truncated_result = {"ModelArn": result.get("ModelArn")}
elif method == "create_endpoint_config":
truncated_result = {"EndpointConfigArn": result.get("EndpointConfigArn")}

ctx = FlyteContextManager.current_context()
builder = ctx.with_file_access(
FileAccessProvider(
Expand All @@ -113,7 +126,7 @@ async def do(
literals={
"result": TypeEngine.to_literal(
new_ctx,
result,
truncated_result if truncated_result else result,
Annotated[dict, kwtypes(allow_pickle=True)],
TypeEngine.to_literal_type(dict),
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,12 @@ def update_dict_fn(
except Exception:
raise ValueError(f"Could not find the key {key} in {update_dict_copy}.")

if len(matches) > 1:
# Replace the placeholder in the original_dict
original_dict = original_dict.replace(f"{{{match}}}", update_dict_copy)
else:
if f"{{{match}}}" == original_dict:
# If there's only one match, it needn't always be a string, so not replacing the original dict.
return update_dict_copy
else:
# Replace the placeholder in the original_dict
original_dict = original_dict.replace(f"{{{match}}}", update_dict_copy)
elif match == "idempotence_token" and idempotence_token:
temp_dict = original_dict.replace(f"{{{match}}}", idempotence_token)
if len(temp_dict) > 63:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(
super().__init__(
name=name,
task_type=self._TASK_TYPE,
interface=Interface(inputs=inputs, outputs=kwtypes(result=dict, idempotence_token=str)),
interface=Interface(inputs=inputs, outputs=kwtypes(result=dict)),
**kwargs,
)
self._config = config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ def create_deployment_task(
)


def append_token(config, key, token, name):
if key in config:
config[key] += f"-{{{token}}}"
else:
config[key] = f"{name}-{{{token}}}"


def create_sagemaker_deployment(
name: str,
model_config: Dict[str, Any],
Expand All @@ -49,6 +56,7 @@ def create_sagemaker_deployment(
endpoint_input_types: Optional[Dict[str, Type]] = None,
region: Optional[str] = None,
region_at_runtime: bool = False,
idempotence_token: bool = True,
) -> Workflow:
"""
Creates SageMaker model, endpoint config and endpoint.
Expand All @@ -62,6 +70,7 @@ def create_sagemaker_deployment(
:param endpoint_input_types: Mapping of SageMaker endpoint inputs to their types.
:param region: The region for SageMaker API calls.
:param region_at_runtime: Set this to True if you want to provide the region at runtime.
:param idempotence_token: Set this to False if you don't want the agent to automatically append a token/hash to the deployment names.
"""
if not any((region, region_at_runtime)):
raise ValueError("Region parameter is required.")
Expand All @@ -71,6 +80,21 @@ def create_sagemaker_deployment(
if region_at_runtime:
wf.add_workflow_input("region", str)

if idempotence_token:
append_token(model_config, "ModelName", "idempotence_token", name)
append_token(endpoint_config_config, "EndpointConfigName", "idempotence_token", name)

if "ProductionVariants" in endpoint_config_config and endpoint_config_config["ProductionVariants"]:
append_token(
endpoint_config_config["ProductionVariants"][0],
"ModelName",
"inputs.idempotence_token",
name,
)

append_token(endpoint_config, "EndpointName", "idempotence_token", name)
append_token(endpoint_config, "EndpointConfigName", "inputs.idempotence_token", name)

inputs = {
SageMakerModelTask: {
"input_types": model_input_types,
Expand Down
89 changes: 38 additions & 51 deletions plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,63 +23,50 @@
"mock_return_value",
[
(
{
"ResponseMetadata": {
"RequestId": "66f80391-348a-4ee0-9158-508914d16db2",
"HTTPStatusCode": 200.0,
"RetryAttempts": 0.0,
"HTTPHeaders": {
"content-type": "application/x-amz-json-1.1",
"date": "Wed, 31 Jan 2024 16:43:52 GMT",
"x-amzn-requestid": "66f80391-348a-4ee0-9158-508914d16db2",
"content-length": "114",
},
(
{
"EndpointConfigArn": "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config",
},
"EndpointConfigArn": "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config",
},
idempotence_token,
idempotence_token,
),
"create_endpoint_config",
),
(
{
"ResponseMetadata": {
"RequestId": "66f80391-348a-4ee0-9158-508914d16db2",
"HTTPStatusCode": 200.0,
"RetryAttempts": 0.0,
"HTTPHeaders": {
"content-type": "application/x-amz-json-1.1",
"date": "Wed, 31 Jan 2024 16:43:52 GMT",
"x-amzn-requestid": "66f80391-348a-4ee0-9158-508914d16db2",
"content-length": "114",
},
(
{
"pickle_check": datetime(2024, 5, 5),
"Location": "http://examplebucket.s3.amazonaws.com/",
},
"pickle_check": datetime(2024, 5, 5),
"EndpointConfigArn": "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config",
},
idempotence_token,
idempotence_token,
),
"create_bucket",
),
(None, idempotence_token),
((None, idempotence_token), "create_endpoint_config"),
(
CustomException(
message="An error occurred",
idempotence_token=idempotence_token,
original_exception=ClientError(
error_response={
"Error": {
"Code": "ValidationException",
"Message": "Cannot create already existing endpoint 'arn:aws:sagemaker:us-east-2:123456789:endpoint/stable-diffusion-endpoint-non-finetuned-06716dbe4b2c68e7'",
}
},
operation_name="DescribeEndpoint",
),
)
(
CustomException(
message="An error occurred",
idempotence_token=idempotence_token,
original_exception=ClientError(
error_response={
"Error": {
"Code": "ValidationException",
"Message": "Cannot create already existing endpoint 'arn:aws:sagemaker:us-east-2:123456789:endpoint/stable-diffusion-endpoint-non-finetuned-06716dbe4b2c68e7'",
}
},
operation_name="DescribeEndpoint",
),
)
),
"create_endpoint_config",
),
],
)
@mock.patch(
"flytekitplugins.awssagemaker_inference.boto3_agent.Boto3AgentMixin._call",
)
async def test_agent(mock_boto_call, mock_return_value):
mock_boto_call.return_value = mock_return_value
mock_boto_call.return_value = mock_return_value[0]

agent = AgentRegistry.get_agent("boto")
task_id = Identifier(
Expand All @@ -106,7 +93,7 @@ async def test_agent(mock_boto_call, mock_return_value):
},
},
"region": "us-east-2",
"method": "create_endpoint_config",
"method": mock_return_value[1],
"images": None,
}
task_metadata = TaskMetadata(
Expand Down Expand Up @@ -149,16 +136,16 @@ async def test_agent(mock_boto_call, mock_return_value):
ctx = FlyteContext.current_context()
output_prefix = ctx.file_access.get_random_remote_directory()

if isinstance(mock_return_value, Exception):
mock_boto_call.side_effect = mock_return_value
if isinstance(mock_return_value[0], Exception):
mock_boto_call.side_effect = mock_return_value[0]

resource = await agent.do(
task_template=task_template,
inputs=task_inputs,
output_prefix=output_prefix,
)
assert resource.outputs["result"] == {
"result": f"Entity already exists: arn:aws:sagemaker:us-east-2:123456789:endpoint/stable-diffusion-endpoint-non-finetuned-06716dbe4b2c68e7"
"EndpointConfigArn": "arn:aws:sagemaker:us-east-2:123456789:endpoint/stable-diffusion-endpoint-non-finetuned-06716dbe4b2c68e7"
}
assert resource.outputs["idempotence_token"] == idempotence_token
return
Expand All @@ -169,15 +156,15 @@ async def test_agent(mock_boto_call, mock_return_value):

assert resource.phase == TaskExecution.SUCCEEDED

if mock_return_value[0]:
if mock_return_value[0][0]:
outputs = literal_map_string_repr(resource.outputs)
if "pickle_check" in mock_return_value[0]:
if "pickle_check" in mock_return_value[0][0]:
assert "pickle_file" in outputs["result"]
else:
assert (
outputs["result"]["EndpointConfigArn"]
== "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config"
)
assert outputs["idempotence_token"] == "74443947857331f7"
elif mock_return_value[0] is None:
elif mock_return_value[0][0] is None:
assert resource.outputs["result"] == {"result": None}
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ async def test_agent(mock_boto_call, mock_return_value):
resource.outputs["result"]["EndpointArn"]
== "arn:aws:sagemaker:us-east-2:1234567890:endpoint/sagemaker-xgboost-endpoint"
)
assert resource.outputs["idempotence_token"] == idempotence_token

# DELETE
delete_response = await agent.delete(metadata)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
kwtypes(endpoint_name=str, endpoint_config_name=str),
None,
2,
2,
1,
"us-east-2",
SageMakerEndpointTask,
),
Expand Down

0 comments on commit bba4e97

Please sign in to comment.