From acfb05d69b4d6b80d543dc9e8009a1ec7febc1fa Mon Sep 17 00:00:00 2001
From: Samhita Alla
Date: Tue, 23 Jul 2024 10:44:23 +0530
Subject: [PATCH] Sagemaker dict determinism (#2597)

* truncate sagemaker agent outputs

Signed-off-by: Samhita Alla

* fix tests and update agent output

Signed-off-by: Samhita Alla

* lint

Signed-off-by: Samhita Alla

* fix test

Signed-off-by: Samhita Alla

* add idempotence token to workflow

Signed-off-by: Samhita Alla

* fix type

Signed-off-by: Samhita Alla

* fix mixin

Signed-off-by: Samhita Alla

* modify output handler

Signed-off-by: Samhita Alla

* make the dictionary deterministic

Signed-off-by: Samhita Alla

* nit

Signed-off-by: Samhita Alla

---------

Signed-off-by: Samhita Alla
Signed-off-by: mao3267
---
 .../awssagemaker_inference/boto3_mixin.py          | 20 +++++++++++++++++++-
 1 file changed, 19 insertions(+), 1 deletion(-)

diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py
index 05dac9de59..7d5c1e4905 100644
--- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py
+++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py
@@ -16,6 +16,24 @@ def __init__(self, message, idempotence_token, original_exception):
         self.original_exception = original_exception
 
 
+def sorted_dict_str(d):
+    """Recursively render ``d`` as an insertion-order-independent string.
+
+    Dict items are emitted in sorted key order and list elements are
+    ordered by their own canonical rendering, so equal structures yield
+    the same string at every nesting depth. Any other value is rendered
+    with ``str``.
+    """
+    if isinstance(d, dict):
+        return "{" + ", ".join(f"{sorted_dict_str(k)}: {sorted_dict_str(v)}" for k, v in sorted(d.items())) + "}"
+    elif isinstance(d, list):
+        # Order by the canonical rendering rather than str(x): str() of a nested
+        # dict follows insertion order, which would leak into the element order.
+        return "[" + ", ".join(sorted(sorted_dict_str(i) for i in d)) + "]"
+    else:
+        return str(d)
+
+
 account_id_map = {
     "us-east-1": "785573368785",
     "us-east-2": "007439368137",
@@ -187,7 +205,7 @@ async def _call(
         hash = ""
         if "idempotence_token" in str(updated_config):
             # compute hash of the config
-            hash = xxhash.xxh64(str(updated_config)).hexdigest()
+            hash = xxhash.xxh64(sorted_dict_str(updated_config)).hexdigest()
             updated_config = update_dict_fn(updated_config, args, idempotence_token=hash)
 
         # Asynchronous Boto3 session