diff --git a/flytekit/interaction/click_types.py b/flytekit/interaction/click_types.py index 2c26ed0cbc..c16339a236 100644 --- a/flytekit/interaction/click_types.py +++ b/flytekit/interaction/click_types.py @@ -10,7 +10,7 @@ import cloudpickle import rich_click as click import yaml -from dataclasses_json import DataClassJsonMixin +from dataclasses_json import DataClassJsonMixin, dataclass_json from pytimeparse import parse from flytekit import BlobType, FlyteContext, FlyteContextManager, Literal, LiteralType, StructuredDataset @@ -273,6 +273,11 @@ def convert( if is_pydantic_basemodel(self._python_type): return self._python_type.parse_raw(json.dumps(parsed_value)) # type: ignore + + # Ensure that the python type has `from_json` function + if not hasattr(self._python_type, "from_json"): + self._python_type = dataclass_json(self._python_type) + return cast(DataClassJsonMixin, self._python_type).from_json(json.dumps(parsed_value)) diff --git a/tests/flytekit/unit/interaction/test_click_types.py b/tests/flytekit/unit/interaction/test_click_types.py index 83a191c449..cb32982916 100644 --- a/tests/flytekit/unit/interaction/test_click_types.py +++ b/tests/flytekit/unit/interaction/test_click_types.py @@ -206,3 +206,23 @@ def test_query_passing(param_type: click.ParamType): query = a.query() assert param_type.convert(value=query, param=None, ctx=None) is query + + +def test_dataclass_type(): + from dataclasses import dataclass + + @dataclass + class Datum: + x: int + y: str + z: dict[int, str] + w: list[int] + + t = JsonParamType(Datum) + value = '{ "x": 1, "y": "2", "z": { "1": "one", "2": "two" }, "w": [1, 2, 3] }' + v = t.convert(value=value, param=None, ctx=None) + + assert v.x == 1 + assert v.y == "2" + assert v.z == {1: "one", 2: "two"} + assert v.w == [1, 2, 3]