pickle dictionary when it isn't JSON serializable #2390

Merged (16 commits) on May 9, 2024
8 changes: 6 additions & 2 deletions flytekit/core/promise.py
@@ -127,7 +127,11 @@ def resolve_attr_path_in_promise(p: Promise) -> Promise:
break

# If the current value is a dataclass, resolve the dataclass with the remaining path
if type(curr_val.value) is _literals_models.Scalar and type(curr_val.value.value) is _struct.Struct:
if (
len(p.attr_path) > 0
and type(curr_val.value) is _literals_models.Scalar
and type(curr_val.value.value) is _struct.Struct
):
st = curr_val.value.value
new_st = resolve_attr_path_in_pb_struct(st, attr_path=p.attr_path[used:])
literal_type = TypeEngine.to_literal_type(type(new_st))
@@ -729,7 +733,7 @@ def binding_data_from_python_std(
lit = TypeEngine.to_literal(ctx, t_value, type(t_value), expected_literal_type)
return _literals_models.BindingData(scalar=lit.scalar)
else:
_, v_type = DictTransformer.get_dict_types(t_value_type)
_, v_type = DictTransformer.extract_types_or_metadata(t_value_type)
m = _literals_models.BindingDataMap(
bindings={
k: binding_data_from_python_std(
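For orientation, here is a small sketch of what the renamed helper returns when binding_data_from_python_std recurses into a map binding; the Dict[str, int] annotation is illustrative, not taken from this PR:

```python
from typing import Dict

from flytekit.core.type_engine import DictTransformer

# extract_types_or_metadata behaves like the old get_dict_types for a plain
# typed dict: it returns the (key, value) generic arguments, and the caller
# keeps only the value type for the recursive binding.
k_type, v_type = DictTransformer.extract_types_or_metadata(Dict[str, int])
assert k_type is str and v_type is int
```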
89 changes: 72 additions & 17 deletions flytekit/core/type_engine.py
@@ -12,6 +12,7 @@
import textwrap
import typing
from abc import ABC, abstractmethod
from collections import OrderedDict
from functools import lru_cache
from typing import Dict, List, NamedTuple, Optional, Type, cast

@@ -713,7 +714,7 @@ def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any:
return list(map(lambda x: self._fix_val_int(ListTransformer.get_sub_type(t), x), val))

if isinstance(val, dict):
ktype, vtype = DictTransformer.get_dict_types(t)
ktype, vtype = DictTransformer.extract_types_or_metadata(t)
# Handle nested Dict. e.g. {1: {2: 3}, 4: {5: 6}})
return {
self._fix_val_int(cast(type, ktype), k): self._fix_val_int(cast(type, vtype), v) for k, v in val.items()
@@ -1660,13 +1661,10 @@ class DictTransformer(TypeTransformer[dict]):
"""

def __init__(self):
super().__init__("Typed Dict", dict)
super().__init__("Python Dictionary", dict)

@staticmethod
def get_dict_types(t: Optional[Type[dict]]) -> typing.Tuple[Optional[type], Optional[type]]:
"""
Return the generic Type T of the Dict
"""
def extract_types_or_metadata(t: Optional[Type[dict]]) -> typing.Tuple:
_origin = get_origin(t)
_args = get_args(t)
if _origin is not None:
@@ -1679,22 +1677,60 @@ def get_dict_types(t: Optional[Type[dict]]) -> typing.Tuple[Optional[type], Opti
raise ValueError(
f"Flytekit does not currently have support for FlyteAnnotations applied to dicts. {t} cannot be parsed."
)
if _origin is dict and _args is not None:
if _origin in [dict, Annotated] and _args is not None:
return _args # type: ignore
return None, None

@staticmethod
def dict_to_generic_literal(v: dict) -> Literal:
def dict_to_generic_literal(v: dict, allow_pickle: bool) -> Literal:
"""
Creates a flyte-specific ``Literal`` value from a native python dictionary.
"""
return Literal(scalar=Scalar(generic=_json_format.Parse(json.dumps(v), _struct.Struct())))
from flytekit.types.pickle import FlytePickle

try:
return Literal(
scalar=Scalar(generic=_json_format.Parse(json.dumps(v), _struct.Struct())),
metadata={"format": "json"},
)
except TypeError as e:
if allow_pickle:
remote_path = FlytePickle.to_pickle(v)
return Literal(
scalar=Scalar(
generic=_json_format.Parse(json.dumps({"pickle_file": remote_path}), _struct.Struct())
),
metadata={"format": "pickle"},
)
raise e

@staticmethod
def is_pickle(python_type: Type[dict]) -> typing.Tuple[bool, Type]:
base_type, *metadata = DictTransformer.extract_types_or_metadata(python_type)

for each_metadata in metadata:
if isinstance(each_metadata, OrderedDict):
allow_pickle = each_metadata.get("allow_pickle", False)
return allow_pickle, base_type

return False, base_type

@staticmethod
def dict_types(python_type: Type) -> typing.Tuple[typing.Any, ...]:
if get_origin(python_type) is Annotated:
base_type, *_ = DictTransformer.extract_types_or_metadata(python_type)
tp = get_args(base_type)
else:
tp = DictTransformer.extract_types_or_metadata(python_type)

return tp

def get_literal_type(self, t: Type[dict]) -> LiteralType:
"""
Transforms a native python dictionary to a flyte-specific ``LiteralType``
"""
tp = self.get_dict_types(t)
tp = self.dict_types(t)

if tp:
if tp[0] == str:
try:
@@ -1710,21 +1746,33 @@ def to_literal(
if type(python_val) != dict:
raise TypeTransformerFailedError("Expected a dict")

allow_pickle = False
base_type = None

if get_origin(python_type) is Annotated:
allow_pickle, base_type = DictTransformer.is_pickle(python_type)

if expected and expected.simple and expected.simple == SimpleType.STRUCT:
return self.dict_to_generic_literal(python_val)
return self.dict_to_generic_literal(python_val, allow_pickle)

lit_map = {}
for k, v in python_val.items():
if type(k) != str:
raise ValueError("Flyte MapType expects all keys to be strings")
# TODO: log a warning for Annotated objects that contain HashMethod
k_type, v_type = self.get_dict_types(python_type)

if base_type:
_, v_type = get_args(base_type)
else:
_, v_type = self.extract_types_or_metadata(python_type)

lit_map[k] = TypeEngine.to_literal(ctx, v, cast(type, v_type), expected.map_value_type)
return Literal(map=LiteralMap(literals=lit_map))

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[dict]) -> dict:
if lv and lv.map and lv.map.literals is not None:
tp = self.get_dict_types(expected_python_type)
tp = self.dict_types(expected_python_type)

if tp is None or tp[0] is None:
raise TypeError(
"TypeMismatch: Cannot convert to python dictionary from Flyte Literal Dictionary as the given "
@@ -1741,10 +1789,17 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
# for empty generic we have to explicitly test for lv.scalar.generic is not None as empty dict
# evaluates to false
if lv and lv.scalar and lv.scalar.generic is not None:
try:
return json.loads(_json_format.MessageToJson(lv.scalar.generic))
except TypeError:
raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")
if lv.metadata["format"] == "json":
try:
return json.loads(_json_format.MessageToJson(lv.scalar.generic))
except TypeError:
raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")
elif lv.metadata["format"] == "pickle":
from flytekit.types.pickle import FlytePickle

uri = json.loads(_json_format.MessageToJson(lv.scalar.generic)).get("pickle_file")
return FlytePickle.from_pickle(uri)

raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")

def guess_python_type(self, literal_type: LiteralType) -> Union[Type[dict], typing.Dict[Type, Type]]:
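A minimal sketch of the new DictTransformer behavior, assuming a local FlyteContext and a writable location for the pickle upload; the dict values and variable names are illustrative:

```python
from datetime import datetime

from typing_extensions import Annotated

from flytekit import FlyteContextManager, kwtypes
from flytekit.core.type_engine import DictTransformer, TypeEngine

ann_dict = Annotated[dict, kwtypes(allow_pickle=True)]

# kwtypes returns an OrderedDict, so is_pickle finds the allow_pickle flag in
# the Annotated metadata and hands back the base dict type.
allow_pickle, base_type = DictTransformer.is_pickle(ann_dict)
assert allow_pickle is True and base_type is dict

# A datetime value makes json.dumps fail, so to_literal falls back to
# FlytePickle and records {"format": "pickle"} in the literal metadata.
ctx = FlyteContextManager.current_context()
lit = TypeEngine.to_literal(
    ctx,
    {"when": datetime(2024, 5, 5)},
    ann_dict,
    TypeEngine.to_literal_type(dict),
)
assert lit.metadata["format"] == "pickle"

# A JSON-serializable dict keeps the struct/generic representation and is
# tagged {"format": "json"} instead.
lit_json = TypeEngine.to_literal(ctx, {"a": 1}, dict, TypeEngine.to_literal_type(dict))
assert lit_json.metadata["format"] == "json"
```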
@@ -1,7 +1,10 @@
from typing import Optional

from flyteidl.core.execution_pb2 import TaskExecution
from typing_extensions import Annotated

from flytekit import FlyteContextManager, kwtypes
from flytekit.core.type_engine import TypeEngine
from flytekit.extend.backend.base_agent import (
AgentRegistry,
Resource,
@@ -54,9 +57,19 @@ async def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = N
inputs=inputs,
)

outputs = None
outputs = {"result": {"result": None}}
if result:
outputs = {"result": result}
ctx = FlyteContextManager.current_context()
outputs = LiteralMap(
literals={
"result": TypeEngine.to_literal(
ctx,
result,
Annotated[dict, kwtypes(allow_pickle=True)],
TypeEngine.to_literal_type(dict),
)
}
)

return Resource(phase=TaskExecution.SUCCEEDED, outputs=outputs)

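To sketch the agent change from the user side, the task below declares the same Annotated[dict, kwtypes(allow_pickle=True)] output the agent uses above; the task name and response fields are illustrative:

```python
from datetime import datetime

from typing_extensions import Annotated

from flytekit import kwtypes, task


@task
def describe_endpoint() -> Annotated[dict, kwtypes(allow_pickle=True)]:
    # boto3-style responses can contain datetimes; with allow_pickle=True the
    # dict is pickled to blob storage instead of failing JSON serialization,
    # and the resulting literal is tagged {"format": "pickle"}.
    return {
        "EndpointConfigArn": "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config",
        "CreationTime": datetime(2024, 5, 5),
    }
```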
@@ -32,7 +32,10 @@ def __init__(
name=name,
task_config=task_config,
task_type=self._TASK_TYPE,
interface=Interface(inputs=inputs, outputs=kwtypes(result=dict)),
interface=Interface(
inputs=inputs,
outputs=kwtypes(result=dict),
),
**kwargs,
)

77 changes: 56 additions & 21 deletions plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py
@@ -1,35 +1,63 @@
from datetime import timedelta
from datetime import datetime, timedelta
from unittest import mock

import pytest
from flyteidl.core.execution_pb2 import TaskExecution

from flytekit.extend.backend.base_agent import AgentRegistry
from flytekit.interaction.string_literals import literal_map_string_repr
from flytekit.interfaces.cli_identifiers import Identifier
from flytekit.models import literals
from flytekit.models.core.identifier import ResourceType
from flytekit.models.task import RuntimeMetadata, TaskMetadata, TaskTemplate


@pytest.mark.asyncio
@pytest.mark.parametrize(
"mock_return_value",
[
(
{
"ResponseMetadata": {
"RequestId": "66f80391-348a-4ee0-9158-508914d16db2",
"HTTPStatusCode": 200.0,
"RetryAttempts": 0.0,
"HTTPHeaders": {
"content-type": "application/x-amz-json-1.1",
"date": "Wed, 31 Jan 2024 16:43:52 GMT",
"x-amzn-requestid": "66f80391-348a-4ee0-9158-508914d16db2",
"content-length": "114",
},
},
"EndpointConfigArn": "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config",
}
),
(
{
"ResponseMetadata": {
"RequestId": "66f80391-348a-4ee0-9158-508914d16db2",
"HTTPStatusCode": 200.0,
"RetryAttempts": 0.0,
"HTTPHeaders": {
"content-type": "application/x-amz-json-1.1",
"date": "Wed, 31 Jan 2024 16:43:52 GMT",
"x-amzn-requestid": "66f80391-348a-4ee0-9158-508914d16db2",
"content-length": "114",
},
},
"pickle_check": datetime(2024, 5, 5),
"EndpointConfigArn": "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config",
}
),
(None),
],
)
@mock.patch(
"flytekitplugins.awssagemaker_inference.boto3_agent.Boto3AgentMixin._call",
return_value={
"ResponseMetadata": {
"RequestId": "66f80391-348a-4ee0-9158-508914d16db2",
"HTTPStatusCode": 200.0,
"RetryAttempts": 0.0,
"HTTPHeaders": {
"content-type": "application/x-amz-json-1.1",
"date": "Wed, 31 Jan 2024 16:43:52 GMT",
"x-amzn-requestid": "66f80391-348a-4ee0-9158-508914d16db2",
"content-length": "114",
},
},
"EndpointConfigArn": "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config",
},
)
async def test_agent(mock_boto_call):
async def test_agent(mock_boto_call, mock_return_value):
mock_boto_call.return_value = mock_return_value

agent = AgentRegistry.get_agent("boto")
task_id = Identifier(
resource_type=ResourceType.TASK,
@@ -88,9 +116,16 @@ async def test_agent(mock_boto_call):
)

resource = await agent.do(task_template, task_inputs)

assert resource.phase == TaskExecution.SUCCEEDED
assert (
resource.outputs["result"]["EndpointConfigArn"]
== "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config"
)

if mock_return_value:
outputs = literal_map_string_repr(resource.outputs)
if "pickle_check" in mock_return_value:
assert "pickle_file" in outputs["result"]
else:
assert (
outputs["result"]["EndpointConfigArn"]
== "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config"
)
elif mock_return_value is None:
assert resource.outputs["result"] == {"result": None}
2 changes: 1 addition & 1 deletion tests/flytekit/unit/core/test_local_cache.py
@@ -529,7 +529,7 @@ def test_stable_cache_key():
}
)
key = _calculate_cache_key("task_name_1", "31415", lm)
assert key == "task_name_1-31415-404b45f8556276183621d4bf37f50049"
assert key == "task_name_1-31415-189e755a8f41c006889c291fcaedb4eb"


@pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.")
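The expected cache key changes, presumably because dict literals now carry format metadata, which alters the serialized literal map that the key is hashed from. A small illustrative sketch of the literal the JSON branch now builds (same construction as dict_to_generic_literal above; the payload is made up):

```python
import json

from google.protobuf import json_format, struct_pb2

from flytekit.models.literals import Literal, Scalar

# The struct payload is unchanged from the old behavior, but the literal now
# also carries {"format": "json"} metadata, so its serialized form differs.
lit = Literal(
    scalar=Scalar(generic=json_format.Parse(json.dumps({"a": 1}), struct_pb2.Struct())),
    metadata={"format": "json"},
)
assert lit.metadata == {"format": "json"}
```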