Skip to content

Commit

Permalink
Fix DataClass Json Schema Error for get literal type method (flyteorg#2587)
Browse files Browse the repository at this point in the history

Signed-off-by: Future-Outlier <[email protected]>
  • Loading branch information
Future-Outlier authored and Mecoli1219 committed Jul 27, 2024
1 parent 8a45055 commit 2e28b5a
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 16 deletions.
19 changes: 14 additions & 5 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from google.protobuf.json_format import ParseDict as _ParseDict
from google.protobuf.message import Message
from google.protobuf.struct_pb2 import Struct
from marshmallow_enum import EnumField, LoadDumpOptions
from mashumaro.codecs.json import JSONDecoder, JSONEncoder
from mashumaro.mixins.json import DataClassJSONMixin
from typing_extensions import Annotated, get_args, get_origin
Expand Down Expand Up @@ -425,6 +424,7 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
Extracts the Literal type definition for a Dataclass and returns a type Struct.
If possible also extracts the JSONSchema for the dataclass.
"""

if is_annotated(t):
args = get_args(t)
for x in args[1:]:
Expand All @@ -439,6 +439,8 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:

schema = None
try:
from marshmallow_enum import EnumField, LoadDumpOptions

if issubclass(t, DataClassJsonMixin):
s = cast(DataClassJsonMixin, self._get_origin_type_in_annotation(t)).schema()
for _, v in s.fields.items():
Expand All @@ -450,10 +452,6 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
from marshmallow_jsonschema import JSONSchema

schema = JSONSchema().dump(s)
else: # DataClassJSONMixin
from mashumaro.jsonschema import build_json_schema

schema = build_json_schema(cast(DataClassJSONMixin, self._get_origin_type_in_annotation(t))).to_dict()
except Exception as e:
# https://github.com/lovasoa/marshmallow_dataclass/issues/13
logger.warning(
Expand All @@ -462,6 +460,17 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
f"evaluation doesn't work with json dataclasses"
)

if schema is None:
try:
from mashumaro.jsonschema import build_json_schema

schema = build_json_schema(cast(DataClassJSONMixin, self._get_origin_type_in_annotation(t))).to_dict()
except Exception as e:
logger.error(
f"Failed to extract schema for object {t}, error: {e}\n"
f"Please remove `DataClassJsonMixin` and `dataclass_json` decorator from the dataclass definition"
)

# Recursively construct the dataclass_type which contains the literal type of each field
literal_type = {}

Expand Down
2 changes: 1 addition & 1 deletion flytekit/types/directory/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def t1(in1: FlyteDirectory["svg"]):

def _serialize(self) -> typing.Dict[str, str]:
lv = FlyteDirToMultipartBlobTransformer().to_literal(
FlyteContextManager.current_context(), self, FlyteDirectory, None
FlyteContextManager.current_context(), self, type(self), None
)
return {"path": lv.scalar.blob.uri}

Expand Down
2 changes: 1 addition & 1 deletion flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def t2() -> flytekit_typing.FlyteFile["csv"]:
"""

def _serialize(self) -> typing.Dict[str, str]:
lv = FlyteFilePathTransformer().to_literal(FlyteContextManager.current_context(), self, FlyteFile, None)
lv = FlyteFilePathTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
return {"path": lv.scalar.blob.uri}

@classmethod
Expand Down
1 change: 1 addition & 0 deletions flytekit/types/schema/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ class FlyteSchema(SerializableType, DataClassJSONMixin):
"""

def _serialize(self) -> typing.Dict[str, typing.Optional[str]]:
# NOTE(review): the literal returned by to_literal is discarded; only
# self.remote_path is serialized. Presumably the call is made for a side
# effect on this object (e.g. populating remote_path) -- confirm against
# FlyteSchemaTransformer.to_literal.
FlyteSchemaTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
return {"remote_path": self.remote_path}

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class (that is just a model, a Python class representation of the protobuf).

def _serialize(self) -> Dict[str, Optional[str]]:
lv = StructuredDatasetTransformerEngine().to_literal(
FlyteContextManager.current_context(), self, StructuredDataset, None
FlyteContextManager.current_context(), self, type(self), None
)
sd = StructuredDataset(uri=lv.scalar.structured_dataset.uri)
sd.file_format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format
Expand Down
27 changes: 26 additions & 1 deletion tests/flytekit/unit/core/test_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import tempfile
from dataclasses import dataclass
from typing import Annotated, List, Dict, Optional

from flytekit.types.schema import FlyteSchema
from flytekit.core.type_engine import TypeEngine
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.type_engine import DataclassTransformer
Expand Down Expand Up @@ -857,3 +857,28 @@ class NestedFlyteTypes(DataClassJSONMixin):
pv = TypeEngine.to_python_value(ctx, lv, NestedFlyteTypes)
assert isinstance(pv, NestedFlyteTypes)
DataclassTransformer().assert_type(NestedFlyteTypes, pv)

def test_get_literal_type_data_class_json_fail_but_mashumaro_works():
    """Check ``get_literal_type`` on ``DataClassJsonMixin`` dataclasses that hold Flyte types.

    A flat dataclass and a nested one (plain, list, dict and optional members)
    are declared with Flyte file/directory/dataset/schema fields. The test
    asserts that the literal type produced by ``DataclassTransformer`` for the
    nested dataclass carries metadata.
    """

    @dataclass
    class FlyteTypesWithDataClassJson(DataClassJsonMixin):
        flytefile: FlyteFile
        flytedir: FlyteDirectory
        structured_dataset: StructuredDataset
        fs: FlyteSchema

    @dataclass
    class NestedFlyteTypesWithDataClassJson(DataClassJsonMixin):
        flytefile: FlyteFile
        flytedir: FlyteDirectory
        structured_dataset: StructuredDataset
        # Declared once: repeating the same annotation in a class body does not
        # add another dataclass field, the name keeps its first position.
        flyte_types: FlyteTypesWithDataClassJson
        fs: FlyteSchema
        list_flyte_types: List[FlyteTypesWithDataClassJson]
        dict_flyte_types: Dict[str, FlyteTypesWithDataClassJson]
        optional_flyte_types: Optional[FlyteTypesWithDataClassJson] = None

    transformer = DataclassTransformer()
    lt = transformer.get_literal_type(NestedFlyteTypesWithDataClassJson)
    assert lt.metadata is not None
9 changes: 2 additions & 7 deletions tests/flytekit/unit/core/test_type_delayed.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import dataclasses
import typing
from dataclasses import dataclass

Expand All @@ -21,13 +22,7 @@ class Foo(DataClassJsonMixin):
def test_jsondc_schemaize():
lt = TypeEngine.to_literal_type(Foo)
pt = TypeEngine.guess_python_type(lt)

# When postponed annotations are enabled, dataclass_json will not work and we'll end up with a
# schemaless generic.
# This test basically tests the broken behavior. Remove this test if
# https://github.com/lovasoa/marshmallow_dataclass/issues/13 is ever fixed.
assert pt is dict

assert dataclasses.is_dataclass(pt)

def test_structured_dataset():
ctx = context_manager.FlyteContext.current_context()
Expand Down

0 comments on commit 2e28b5a

Please sign in to comment.