diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 1ce6a05488..6656c0c293 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -360,6 +360,7 @@ def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T): expected_type = get_underlying_type(expected_type) expected_fields_dict = {} + for f in dataclasses.fields(expected_type): expected_fields_dict[f.name] = f.type @@ -539,11 +540,13 @@ def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]: field.type = self._get_origin_type_in_annotation(field.type) return python_type - def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing.Any) -> T: + def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing.Any) -> T | None: # In python 3.7, 3.8, DataclassJson will deserialize Annotated[StructuredDataset, kwtypes(..)] to a dict, # so here we convert it back to the Structured Dataset. from flytekit.types.structured import StructuredDataset + if python_val is None: + return python_val if python_type == StructuredDataset and type(python_val) == dict: return StructuredDataset(**python_val) elif get_origin(python_type) is list: @@ -575,9 +578,13 @@ def _make_dataclass_serializable(self, python_val: T, python_type: Type[T]) -> t return self._make_dataclass_serializable(python_val, get_args(python_type)[0]) if hasattr(python_type, "__origin__") and get_origin(python_type) is list: + if python_val is None: + return None return [self._make_dataclass_serializable(v, get_args(python_type)[0]) for v in cast(list, python_val)] if hasattr(python_type, "__origin__") and get_origin(python_type) is dict: + if python_val is None: + return None return { k: self._make_dataclass_serializable(v, get_args(python_type)[1]) for k, v in cast(dict, python_val).items() diff --git a/flytekit/interaction/click_types.py b/flytekit/interaction/click_types.py index 101ecea3d1..04a1848f84 100644 --- a/flytekit/interaction/click_types.py +++ b/flytekit/interaction/click_types.py @@ -1,3 +1,4 @@ +import dataclasses import datetime import enum import json @@ -5,7 +6,7 @@ import os import pathlib import typing -from typing import cast +from typing import cast, get_args import rich_click as click import yaml @@ -22,6 +23,7 @@ from flytekit.types.file import FlyteFile from flytekit.types.iterator.json_iterator import JSONIteratorTransformer from flytekit.types.pickle.pickle import FlytePickleTransformer +from flytekit.types.schema.types import FlyteSchema def is_pydantic_basemodel(python_type: typing.Type) -> bool: @@ -305,11 +307,50 @@ def convert( if value is None: raise click.BadParameter("None value cannot be converted to a Json type.") + FLYTE_TYPES = [FlyteFile, FlyteDirectory, StructuredDataset, FlyteSchema] + + def has_nested_dataclass(t: typing.Type) -> bool: + """ + Recursively checks whether the given type or its nested types contain any dataclass. + + This function is typically called with a dictionary or list type and will return True if + any of the nested types within the dictionary or list is a dataclass. + + Note: + - A single dataclass will return True. + - The function specifically excludes certain Flyte types like FlyteFile, FlyteDirectory, + StructuredDataset, and FlyteSchema from being considered as dataclasses. This is because + these types are handled separately by Flyte and do not need to be converted to dataclasses. + + Args: + t (typing.Type): The type to check for nested dataclasses. + + Returns: + bool: True if the type or its nested types contain a dataclass, False otherwise. + """ + + if dataclasses.is_dataclass(t): + # FlyteTypes is not supported now, we can support it in the future. + return t not in FLYTE_TYPES + + return any(has_nested_dataclass(arg) for arg in get_args(t)) + parsed_value = self._parse(value, param) # We compare the origin type because the json parsed value for list or dict is always a list or dict without # the covariant type information. if type(parsed_value) == typing.get_origin(self._python_type) or type(parsed_value) == self._python_type: + # Indexing the return value of get_args will raise an error for native dict and list types. + # We don't support native list/dict types with nested dataclasses. + if get_args(self._python_type) == (): + return parsed_value + elif isinstance(parsed_value, list) and has_nested_dataclass(get_args(self._python_type)[0]): + j = JsonParamType(get_args(self._python_type)[0]) + return [j.convert(v, param, ctx) for v in parsed_value] + elif isinstance(parsed_value, dict) and has_nested_dataclass(get_args(self._python_type)[1]): + j = JsonParamType(get_args(self._python_type)[1]) + return {k: j.convert(v, param, ctx) for k, v in parsed_value.items()} + return parsed_value if is_pydantic_basemodel(self._python_type): diff --git a/tests/flytekit/unit/cli/pyflyte/my_wf_input.json b/tests/flytekit/unit/cli/pyflyte/my_wf_input.json index c20081f3b2..4c596e4d55 100644 --- a/tests/flytekit/unit/cli/pyflyte/my_wf_input.json +++ b/tests/flytekit/unit/cli/pyflyte/my_wf_input.json @@ -42,6 +42,9 @@ }, "p": "None", "q": "tests/flytekit/unit/cli/pyflyte/testdata", + "r": [{"i": 1, "a": ["h", "e"]}], + "s": {"x": {"i": 1, "a": ["h", "e"]}}, + "t": {"i": [{"i":1,"a":["h","e"]}]}, "remote": "tests/flytekit/unit/cli/pyflyte/testdata", "image": "tests/flytekit/unit/cli/pyflyte/testdata" } diff --git a/tests/flytekit/unit/cli/pyflyte/my_wf_input.yaml b/tests/flytekit/unit/cli/pyflyte/my_wf_input.yaml index 678f5331c8..5f15826b80 100644 --- a/tests/flytekit/unit/cli/pyflyte/my_wf_input.yaml +++ b/tests/flytekit/unit/cli/pyflyte/my_wf_input.yaml @@ -30,5 +30,22 @@ o: - tests/flytekit/unit/cli/pyflyte/testdata/df.parquet p: 'None' q: tests/flytekit/unit/cli/pyflyte/testdata +r: + - i: 1 + a: + - h + - e +s: + x: + i: 1 + a: + - h + - e +t: + i: + - i: 1 + a: + - h + - e remote: tests/flytekit/unit/cli/pyflyte/testdata image: tests/flytekit/unit/cli/pyflyte/testdata diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index 475fb42ff1..58c4518f3d 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -201,6 +201,12 @@ def test_pyflyte_run_cli(workflow_file): "Any", "--q", DIR_NAME, + "--r", + json.dumps([{"i": 1, "a": ["h", "e"]}]), + "--s", + json.dumps({"x": {"i": 1, "a": ["h", "e"]}}), + "--t", + json.dumps({"i": [{"i":1,"a":["h","e"]}]}), ], catch_exceptions=False, ) diff --git a/tests/flytekit/unit/cli/pyflyte/workflow.py b/tests/flytekit/unit/cli/pyflyte/workflow.py index accebf82df..104538c338 100644 --- a/tests/flytekit/unit/cli/pyflyte/workflow.py +++ b/tests/flytekit/unit/cli/pyflyte/workflow.py @@ -35,6 +35,9 @@ class MyDataclass(DataClassJsonMixin): i: int a: typing.List[str] +@dataclass +class NestedDataclass(DataClassJsonMixin): + i: typing.List[MyDataclass] class Color(enum.Enum): RED = "RED" @@ -61,8 +64,11 @@ def print_all( o: typing.Dict[str, typing.List[FlyteFile]], p: typing.Any, q: FlyteDirectory, + r: typing.List[MyDataclass], + s: typing.Dict[str, MyDataclass], + t: NestedDataclass, ): - print(f"{a}, {b}, {c}, {d}, {e}, {f}, {g}, {h}, {i}, {j}, {k}, {l}, {m}, {n}, {o}, {p}, {q}") + print(f"{a}, {b}, {c}, {d}, {e}, {f}, {g}, {h}, {i}, {j}, {k}, {l}, {m}, {n}, {o}, {p}, {q}, {r}, {s}, {t}") @task @@ -93,6 +99,9 @@ def my_wf( o: typing.Dict[str, typing.List[FlyteFile]], p: typing.Any, q: FlyteDirectory, + r: typing.List[MyDataclass], + s: typing.Dict[str, MyDataclass], + t: NestedDataclass, remote: pd.DataFrame, image: StructuredDataset, m: dict = {"hello": "world"}, @@ -100,7 +109,7 @@ def my_wf( x = get_subset_df(df=remote) # noqa: shown for demonstration; users should use the same types between tasks show_sd(in_sd=x) show_sd(in_sd=image) - print_all(a=a, b=b, c=c, d=d, e=e, f=f, g=g, h=h, i=i, j=j, k=k, l=l, m=m, n=n, o=o, p=p, q=q) + print_all(a=a, b=b, c=c, d=d, e=e, f=f, g=g, h=h, i=i, j=j, k=k, l=l, m=m, n=n, o=o, p=p, q=q, r=r, s=s, t=t) return x diff --git a/tests/flytekit/unit/interaction/test_click_types.py b/tests/flytekit/unit/interaction/test_click_types.py index a9ccfe61b3..11cfb374d8 100644 --- a/tests/flytekit/unit/interaction/test_click_types.py +++ b/tests/flytekit/unit/interaction/test_click_types.py @@ -1,3 +1,4 @@ +from dataclasses import field import json import tempfile import typing @@ -270,3 +271,230 @@ class Datum: assert v.y == "2" assert v.z == {1: "one", 2: "two"} assert v.w == [1, 2, 3] + + +def test_nested_dataclass_type(): + from dataclasses import dataclass + + @dataclass + class Datum: + w: int + x: str = "default" + y: typing.Dict[str, str] = field(default_factory=lambda: {"key": "value"}) + z: typing.List[int] = field(default_factory=lambda: [1, 2, 3]) + + @dataclass + class NestedDatum: + w: Datum + x: typing.List[Datum] + y: typing.Dict[str, Datum] = field(default_factory=lambda: {"key": Datum(1)}) + + + # typing.List[Datum] + value = '[{ "w": 1 }]' + t = JsonParamType(typing.List[Datum]) + v = t.convert(value=value, param=None, ctx=None) + + ctx = FlyteContextManager.current_context() + lt = TypeEngine.to_literal_type(typing.List[Datum]) + literal_converter = FlyteLiteralConverter( + ctx, literal_type=lt, python_type=typing.List[Datum], is_remote=False + ) + v = literal_converter.convert(ctx, None, v) + + assert v[0].w == 1 + assert v[0].x == "default" + assert v[0].y == {"key": "value"} + assert v[0].z == [1, 2, 3] + + # typing.Dict[str, Datum] + value = '{ "x": { "w": 1 } }' + t = JsonParamType(typing.Dict[str, Datum]) + v = t.convert(value=value, param=None, ctx=None) + ctx = FlyteContextManager.current_context() + lt = TypeEngine.to_literal_type(typing.Dict[str, Datum]) + literal_converter = FlyteLiteralConverter( + ctx, literal_type=lt, python_type=typing.Dict[str, Datum], is_remote=False + ) + v = literal_converter.convert(ctx, None, v) + + assert v["x"].w == 1 + assert v["x"].x == "default" + assert v["x"].y == {"key": "value"} + assert v["x"].z == [1, 2, 3] + + # typing.List[NestedDatum] + value = '[{"w":{ "w" : 1 },"x":[{ "w" : 1 }]}]' + t = JsonParamType(typing.List[NestedDatum]) + v = t.convert(value=value, param=None, ctx=None) + ctx = FlyteContextManager.current_context() + lt = TypeEngine.to_literal_type(typing.List[NestedDatum]) + literal_converter = FlyteLiteralConverter( + ctx, literal_type=lt, python_type=typing.List[NestedDatum], is_remote=False + ) + v = literal_converter.convert(ctx, None, v) + + assert v[0].w.w == 1 + assert v[0].w.x == "default" + assert v[0].w.y == {"key": "value"} + assert v[0].w.z == [1, 2, 3] + assert v[0].x[0].w == 1 + assert v[0].x[0].x == "default" + assert v[0].x[0].y == {"key": "value"} + assert v[0].x[0].z == [1, 2, 3] + + # typing.List[typing.List[Datum]] + value = '[[{ "w": 1 }]]' + t = JsonParamType(typing.List[typing.List[Datum]]) + v = t.convert(value=value, param=None, ctx=None) + ctx = FlyteContextManager.current_context() + lt = TypeEngine.to_literal_type(typing.List[typing.List[Datum]]) + literal_converter = FlyteLiteralConverter( + ctx, literal_type=lt, python_type=typing.List[typing.List[Datum]], is_remote=False + ) + v = literal_converter.convert(ctx, None, v) + + assert v[0][0].w == 1 + assert v[0][0].x == "default" + assert v[0][0].y == {"key": "value"} + assert v[0][0].z == [1, 2, 3] + +def test_dataclass_with_default_none(): + from dataclasses import dataclass + + @dataclass + class Datum: + x: int + y: str = None + z: typing.Dict[int, str] = None + w: typing.List[int] = None + + t = JsonParamType(Datum) + value = '{ "x": 1 }' + v = t.convert(value=value, param=None, ctx=None) + lt = TypeEngine.to_literal_type(Datum) + ctx = FlyteContextManager.current_context() + literal_converter = FlyteLiteralConverter( + ctx, literal_type=lt, python_type=Datum, is_remote=False + ) + v = literal_converter.convert(ctx=ctx, param=None, value=v) + + assert v.x == 1 + assert v.y is None + assert v.z is None + assert v.w is None + + +def test_dataclass_with_flyte_type_exception(): + from dataclasses import dataclass + from flytekit import StructuredDataset + from flytekit.types.directory import FlyteDirectory + from flytekit.types.file import FlyteFile + import os + + DIR_NAME = os.path.dirname(os.path.realpath(__file__)) + parquet_file = os.path.join(DIR_NAME, "testdata/df.parquet") + + @dataclass + class Datum: + x: FlyteFile + y: FlyteDirectory + z: StructuredDataset + + t = JsonParamType(Datum) + value = { "x": parquet_file, "y": DIR_NAME, "z": os.path.join(DIR_NAME, "testdata")} + + with pytest.raises(AttributeError): + t.convert(value=value, param=None, ctx=None) + +def test_dataclass_with_optional_fields(): + from dataclasses import dataclass + from typing import Optional + + @dataclass + class Datum: + x: int + y: Optional[str] = None + z: Optional[typing.Dict[int, str]] = None + w: Optional[typing.List[int]] = None + + t = JsonParamType(Datum) + value = '{ "x": 1 }' + v = t.convert(value=value, param=None, ctx=None) + lt = TypeEngine.to_literal_type(Datum) + ctx = FlyteContextManager.current_context() + literal_converter = FlyteLiteralConverter( + ctx, literal_type=lt, python_type=Datum, is_remote=False + ) + v = literal_converter.convert(ctx=ctx, param=None, value=v) + + # Assertions to check the Optional fields + assert v.x == 1 + assert v.y is None # Optional field with no value provided + assert v.z is None # Optional field with no value provided + assert v.w is None # Optional field with no value provided + + # Test with all fields provided + value = '{ "x": 2, "y": "test", "z": {"1": "value"}, "w": [1, 2, 3] }' + v = t.convert(value=value, param=None, ctx=None) + v = literal_converter.convert(ctx=ctx, param=None, value=v) + + assert v.x == 2 + assert v.y == "test" + assert v.z == {1: "value"} + assert v.w == [1, 2, 3] + +def test_nested_dataclass_with_optional_fields(): + from dataclasses import dataclass + from typing import Optional, List, Dict + + @dataclass + class InnerDatum: + a: int + b: Optional[str] = None + + @dataclass + class Datum: + x: int + y: Optional[InnerDatum] = None + z: Optional[Dict[str, InnerDatum]] = None + w: Optional[List[InnerDatum]] = None + + t = JsonParamType(Datum) + + # Case 1: Only required field provided + value = '{ "x": 1 }' + v = t.convert(value=value, param=None, ctx=None) + lt = TypeEngine.to_literal_type(Datum) + ctx = FlyteContextManager.current_context() + literal_converter = FlyteLiteralConverter( + ctx, literal_type=lt, python_type=Datum, is_remote=False + ) + v = literal_converter.convert(ctx=ctx, param=None, value=v) + + # Assertions to check the Optional fields + assert v.x == 1 + assert v.y is None # Optional field with no value provided + assert v.z is None # Optional field with no value provided + assert v.w is None # Optional field with no value provided + + # Case 2: All fields provided with nested structures + value = ''' + { + "x": 2, + "y": {"a": 10, "b": "inner"}, + "z": {"key": {"a": 20, "b": "nested"}}, + "w": [{"a": 30, "b": "list_item"}] + } + ''' + v = t.convert(value=value, param=None, ctx=None) + v = literal_converter.convert(ctx=ctx, param=None, value=v) + + # Assertions for nested structure + assert v.x == 2 + assert v.y.a == 10 + assert v.y.b == "inner" + assert v.z["key"].a == 20 + assert v.z["key"].b == "nested" + assert v.w[0].a == 30 + assert v.w[0].b == "list_item"