Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

truncate sagemaker agent outputs and automate idempotence token handling #2588

Merged
merged 8 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,7 @@ async def get(self, resource_meta: SageMakerEndpointMetadata, **kwargs) -> Resou

res = None
if current_state == "InService":
res = {
"result": {"EndpointArn": endpoint_status.get("EndpointArn")},
"idempotence_token": idempotence_token,
}
res = {"result": {"EndpointArn": endpoint_status.get("EndpointArn")}}

return Resource(phase=flyte_phase, outputs=res, message=message)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ async def do(

boto3_object = Boto3AgentMixin(service=service, region=region)

result = None
try:
result, idempotence_token = await boto3_object._call(
method=method,
Expand All @@ -77,10 +78,15 @@ async def do(
error_message,
).group(0)
if arn:
if method == "create_model":
result = {"ModelArn": arn}
elif method == "create_endpoint_config":
result = {"EndpointConfigArn": arn}

return Resource(
phase=TaskExecution.SUCCEEDED,
outputs={
"result": {"result": f"Entity already exists: {arn}"},
"result": result,
"idempotence_token": e.idempotence_token,
},
)
Expand All @@ -100,6 +106,12 @@ async def do(

outputs = {"result": {"result": None}}
if result:
truncated_result = None
if method == "create_model":
truncated_result = {"ModelArn": result.get("ModelArn")}
elif method == "create_endpoint_config":
truncated_result = {"EndpointConfigArn": result.get("EndpointConfigArn")}

ctx = FlyteContextManager.current_context()
builder = ctx.with_file_access(
FileAccessProvider(
Expand All @@ -113,7 +125,7 @@ async def do(
literals={
"result": TypeEngine.to_literal(
new_ctx,
result,
truncated_result if truncated_result else result,
Annotated[dict, kwtypes(allow_pickle=True)],
TypeEngine.to_literal_type(dict),
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,12 @@ def update_dict_fn(
except Exception:
raise ValueError(f"Could not find the key {key} in {update_dict_copy}.")

if len(matches) > 1:
# Replace the placeholder in the original_dict
original_dict = original_dict.replace(f"{{{match}}}", update_dict_copy)
else:
if f"{{{match}}}" == original_dict:
# The placeholder is the entire string, so the substituted value needn't be a string; return it as-is instead of doing a string replace.
return update_dict_copy
else:
# Replace the placeholder in the original_dict
original_dict = original_dict.replace(f"{{{match}}}", update_dict_copy)
elif match == "idempotence_token" and idempotence_token:
temp_dict = original_dict.replace(f"{{{match}}}", idempotence_token)
if len(temp_dict) > 63:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(
super().__init__(
name=name,
task_type=self._TASK_TYPE,
interface=Interface(inputs=inputs, outputs=kwtypes(result=dict, idempotence_token=str)),
interface=Interface(inputs=inputs, outputs=kwtypes(result=dict)),
**kwargs,
)
self._config = config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ def create_deployment_task(
)


def append_token(config, key, token, name):
    """
    Ensure ``config[key]`` carries a ``{token}`` placeholder, updating ``config`` in place.

    If ``key`` is absent, the value is created as ``"<name>-{<token>}"``. If ``key`` is
    present, ``"-{<token>}"`` is appended to the existing value unless that value already
    contains the placeholder.

    :param config: Dictionary to update in place.
    :param key: Key whose value should carry the placeholder.
    :param token: Placeholder name; it is written wrapped in literal curly braces.
    :param name: Prefix used to build the value when ``key`` is missing from ``config``.
    """
    placeholder = f"{{{token}}}"
    if key not in config:
        config[key] = f"{name}-{placeholder}"
    elif placeholder not in config[key]:
        # The dict is mutated in place, so a second call with the same dict (or a value
        # that already includes the placeholder) must not add a duplicate suffix.
        config[key] += f"-{placeholder}"


def create_sagemaker_deployment(
name: str,
model_config: Dict[str, Any],
Expand All @@ -49,6 +56,7 @@ def create_sagemaker_deployment(
endpoint_input_types: Optional[Dict[str, Type]] = None,
region: Optional[str] = None,
region_at_runtime: bool = False,
idempotence_token: bool = True,
) -> Workflow:
"""
Creates SageMaker model, endpoint config and endpoint.
Expand All @@ -62,6 +70,7 @@ def create_sagemaker_deployment(
:param endpoint_input_types: Mapping of SageMaker endpoint inputs to their types.
:param region: The region for SageMaker API calls.
:param region_at_runtime: Set this to True if you want to provide the region at runtime.
:param idempotence_token: Set this to False if you don't want the agent to automatically append a token/hash to the deployment names.
"""
if not any((region, region_at_runtime)):
raise ValueError("Region parameter is required.")
Expand All @@ -71,6 +80,21 @@ def create_sagemaker_deployment(
if region_at_runtime:
wf.add_workflow_input("region", str)

if idempotence_token:
append_token(model_config, "ModelName", "idempotence_token", name)
append_token(endpoint_config_config, "EndpointConfigName", "idempotence_token", name)

if "ProductionVariants" in endpoint_config_config and endpoint_config_config["ProductionVariants"]:
append_token(
endpoint_config_config["ProductionVariants"][0],
"ModelName",
"inputs.idempotence_token",
name,
)

append_token(endpoint_config, "EndpointName", "idempotence_token", name)
append_token(endpoint_config, "EndpointConfigName", "inputs.idempotence_token", name)

inputs = {
SageMakerModelTask: {
"input_types": model_input_types,
Expand Down
89 changes: 38 additions & 51 deletions plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,63 +23,50 @@
"mock_return_value",
[
(
{
"ResponseMetadata": {
"RequestId": "66f80391-348a-4ee0-9158-508914d16db2",
"HTTPStatusCode": 200.0,
"RetryAttempts": 0.0,
"HTTPHeaders": {
"content-type": "application/x-amz-json-1.1",
"date": "Wed, 31 Jan 2024 16:43:52 GMT",
"x-amzn-requestid": "66f80391-348a-4ee0-9158-508914d16db2",
"content-length": "114",
},
(
{
"EndpointConfigArn": "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config",
},
"EndpointConfigArn": "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config",
},
idempotence_token,
idempotence_token,
),
"create_endpoint_config",
),
(
{
"ResponseMetadata": {
"RequestId": "66f80391-348a-4ee0-9158-508914d16db2",
"HTTPStatusCode": 200.0,
"RetryAttempts": 0.0,
"HTTPHeaders": {
"content-type": "application/x-amz-json-1.1",
"date": "Wed, 31 Jan 2024 16:43:52 GMT",
"x-amzn-requestid": "66f80391-348a-4ee0-9158-508914d16db2",
"content-length": "114",
},
(
{
"pickle_check": datetime(2024, 5, 5),
"Location": "http://examplebucket.s3.amazonaws.com/",
},
"pickle_check": datetime(2024, 5, 5),
"EndpointConfigArn": "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config",
},
idempotence_token,
idempotence_token,
),
"create_bucket",
),
(None, idempotence_token),
((None, idempotence_token), "create_endpoint_config"),
(
CustomException(
message="An error occurred",
idempotence_token=idempotence_token,
original_exception=ClientError(
error_response={
"Error": {
"Code": "ValidationException",
"Message": "Cannot create already existing endpoint 'arn:aws:sagemaker:us-east-2:123456789:endpoint/stable-diffusion-endpoint-non-finetuned-06716dbe4b2c68e7'",
}
},
operation_name="DescribeEndpoint",
),
)
(
CustomException(
message="An error occurred",
idempotence_token=idempotence_token,
original_exception=ClientError(
error_response={
"Error": {
"Code": "ValidationException",
"Message": "Cannot create already existing endpoint 'arn:aws:sagemaker:us-east-2:123456789:endpoint/stable-diffusion-endpoint-non-finetuned-06716dbe4b2c68e7'",
}
},
operation_name="DescribeEndpoint",
),
)
),
"create_endpoint_config",
),
],
)
@mock.patch(
"flytekitplugins.awssagemaker_inference.boto3_agent.Boto3AgentMixin._call",
)
async def test_agent(mock_boto_call, mock_return_value):
mock_boto_call.return_value = mock_return_value
mock_boto_call.return_value = mock_return_value[0]

agent = AgentRegistry.get_agent("boto")
task_id = Identifier(
Expand All @@ -106,7 +93,7 @@ async def test_agent(mock_boto_call, mock_return_value):
},
},
"region": "us-east-2",
"method": "create_endpoint_config",
"method": mock_return_value[1],
"images": None,
}
task_metadata = TaskMetadata(
Expand Down Expand Up @@ -149,16 +136,16 @@ async def test_agent(mock_boto_call, mock_return_value):
ctx = FlyteContext.current_context()
output_prefix = ctx.file_access.get_random_remote_directory()

if isinstance(mock_return_value, Exception):
mock_boto_call.side_effect = mock_return_value
if isinstance(mock_return_value[0], Exception):
mock_boto_call.side_effect = mock_return_value[0]

resource = await agent.do(
task_template=task_template,
inputs=task_inputs,
output_prefix=output_prefix,
)
assert resource.outputs["result"] == {
"result": f"Entity already exists: arn:aws:sagemaker:us-east-2:123456789:endpoint/stable-diffusion-endpoint-non-finetuned-06716dbe4b2c68e7"
"EndpointConfigArn": "arn:aws:sagemaker:us-east-2:123456789:endpoint/stable-diffusion-endpoint-non-finetuned-06716dbe4b2c68e7"
}
assert resource.outputs["idempotence_token"] == idempotence_token
return
Expand All @@ -169,15 +156,15 @@ async def test_agent(mock_boto_call, mock_return_value):

assert resource.phase == TaskExecution.SUCCEEDED

if mock_return_value[0]:
if mock_return_value[0][0]:
outputs = literal_map_string_repr(resource.outputs)
if "pickle_check" in mock_return_value[0]:
if "pickle_check" in mock_return_value[0][0]:
assert "pickle_file" in outputs["result"]
else:
assert (
outputs["result"]["EndpointConfigArn"]
== "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config"
)
assert outputs["idempotence_token"] == "74443947857331f7"
elif mock_return_value[0] is None:
elif mock_return_value[0][0] is None:
assert resource.outputs["result"] == {"result": None}
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ async def test_agent(mock_boto_call, mock_return_value):
resource.outputs["result"]["EndpointArn"]
== "arn:aws:sagemaker:us-east-2:1234567890:endpoint/sagemaker-xgboost-endpoint"
)
assert resource.outputs["idempotence_token"] == idempotence_token

# DELETE
delete_response = await agent.delete(metadata)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
kwtypes(endpoint_name=str, endpoint_config_name=str),
None,
2,
2,
1,
"us-east-2",
SageMakerEndpointTask,
),
Expand Down
Loading