diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index bd617a161aa..3165b4cdf57 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -24,7 +24,6 @@ from google.protobuf.json_format import ParseDict as _ParseDict from google.protobuf.message import Message from google.protobuf.struct_pb2 import Struct -from marshmallow_enum import EnumField, LoadDumpOptions from mashumaro.codecs.json import JSONDecoder, JSONEncoder from mashumaro.mixins.json import DataClassJSONMixin from typing_extensions import Annotated, get_args, get_origin @@ -425,6 +424,7 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: Extracts the Literal type definition for a Dataclass and returns a type Struct. If possible also extracts the JSONSchema for the dataclass. """ + if is_annotated(t): args = get_args(t) for x in args[1:]: @@ -439,6 +439,8 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: schema = None try: + from marshmallow_enum import EnumField, LoadDumpOptions + if issubclass(t, DataClassJsonMixin): s = cast(DataClassJsonMixin, self._get_origin_type_in_annotation(t)).schema() for _, v in s.fields.items(): @@ -450,10 +452,6 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: from marshmallow_jsonschema import JSONSchema schema = JSONSchema().dump(s) - else: # DataClassJSONMixin - from mashumaro.jsonschema import build_json_schema - - schema = build_json_schema(cast(DataClassJSONMixin, self._get_origin_type_in_annotation(t))).to_dict() except Exception as e: # https://github.com/lovasoa/marshmallow_dataclass/issues/13 logger.warning( @@ -462,6 +460,17 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: f"evaluation doesn't work with json dataclasses" ) + if schema is None: + try: + from mashumaro.jsonschema import build_json_schema + + schema = build_json_schema(cast(DataClassJSONMixin, self._get_origin_type_in_annotation(t))).to_dict() + except Exception as e: + logger.error( + f"Failed to extract schema for object {t}, error: {e}\n" + f"Please remove `DataClassJsonMixin` and `dataclass_json` decorator from the dataclass definition" + ) + # Recursively construct the dataclass_type which contains the literal type of each field literal_type = {} diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index dc294134a1b..eb01cdd0394 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -123,7 +123,7 @@ def t1(in1: FlyteDirectory["svg"]): def _serialize(self) -> typing.Dict[str, str]: lv = FlyteDirToMultipartBlobTransformer().to_literal( - FlyteContextManager.current_context(), self, FlyteDirectory, None + FlyteContextManager.current_context(), self, type(self), None ) return {"path": lv.scalar.blob.uri} diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index cc7ba66bed6..e703f71ccd7 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -147,7 +147,7 @@ def t2() -> flytekit_typing.FlyteFile["csv"]: """ def _serialize(self) -> typing.Dict[str, str]: - lv = FlyteFilePathTransformer().to_literal(FlyteContextManager.current_context(), self, FlyteFile, None) + lv = FlyteFilePathTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None) return {"path": lv.scalar.blob.uri} @classmethod diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index cbfbc9eb89d..88adad2681c 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -185,6 +185,7 @@ class FlyteSchema(SerializableType, DataClassJSONMixin): """ def _serialize(self) -> typing.Dict[str, typing.Optional[str]]: + FlyteSchemaTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None) return {"remote_path": self.remote_path} @classmethod diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 56f42a41604..c11519462ee 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -57,7 +57,7 @@ class (that is just a model, a Python class representation of the protobuf). def _serialize(self) -> Dict[str, Optional[str]]: lv = StructuredDatasetTransformerEngine().to_literal( - FlyteContextManager.current_context(), self, StructuredDataset, None + FlyteContextManager.current_context(), self, type(self), None ) sd = StructuredDataset(uri=lv.scalar.structured_dataset.uri) sd.file_format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format diff --git a/tests/flytekit/unit/core/test_dataclass.py b/tests/flytekit/unit/core/test_dataclass.py index ee72f1cc846..f07f51f7aef 100644 --- a/tests/flytekit/unit/core/test_dataclass.py +++ b/tests/flytekit/unit/core/test_dataclass.py @@ -5,7 +5,7 @@ import tempfile from dataclasses import dataclass from typing import Annotated, List, Dict, Optional - +from flytekit.types.schema import FlyteSchema from flytekit.core.type_engine import TypeEngine from flytekit.core.context_manager import FlyteContextManager from flytekit.core.type_engine import DataclassTransformer @@ -857,3 +857,28 @@ class NestedFlyteTypes(DataClassJSONMixin): pv = TypeEngine.to_python_value(ctx, lv, NestedFlyteTypes) assert isinstance(pv, NestedFlyteTypes) DataclassTransformer().assert_type(NestedFlyteTypes, pv) + +def test_get_literal_type_data_class_json_fail_but_mashumaro_works(): + @dataclass + class FlyteTypesWithDataClassJson(DataClassJsonMixin): + flytefile: FlyteFile + flytedir: FlyteDirectory + structured_dataset: StructuredDataset + fs: FlyteSchema + + @dataclass + class NestedFlyteTypesWithDataClassJson(DataClassJsonMixin): + flytefile: FlyteFile + flytedir: FlyteDirectory + structured_dataset: StructuredDataset + flyte_types: FlyteTypesWithDataClassJson + fs: FlyteSchema + flyte_types: FlyteTypesWithDataClassJson + list_flyte_types: List[FlyteTypesWithDataClassJson] + dict_flyte_types: Dict[str, FlyteTypesWithDataClassJson] + flyte_types: FlyteTypesWithDataClassJson + optional_flyte_types: Optional[FlyteTypesWithDataClassJson] = None + + transformer = DataclassTransformer() + lt = transformer.get_literal_type(NestedFlyteTypesWithDataClassJson) + assert lt.metadata is not None diff --git a/tests/flytekit/unit/core/test_type_delayed.py b/tests/flytekit/unit/core/test_type_delayed.py index a47a0b88f84..f35792b8205 100644 --- a/tests/flytekit/unit/core/test_type_delayed.py +++ b/tests/flytekit/unit/core/test_type_delayed.py @@ -1,5 +1,6 @@ from __future__ import annotations +import dataclasses import typing from dataclasses import dataclass @@ -21,13 +22,7 @@ class Foo(DataClassJsonMixin): def test_jsondc_schemaize(): lt = TypeEngine.to_literal_type(Foo) pt = TypeEngine.guess_python_type(lt) - - # When postponed annotations are enabled, dataclass_json will not work and we'll end up with a - # schemaless generic. - # This test basically tests the broken behavior. Remove this test if - # https://github.com/lovasoa/marshmallow_dataclass/issues/13 is ever fixed. - assert pt is dict - + assert dataclasses.is_dataclass(pt) def test_structured_dataset(): ctx = context_manager.FlyteContext.current_context()