Skip to content

Commit

Permalink
Improve flux json serializer for dataclass with post_init (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
pshanmukrao authored Mar 19, 2024
1 parent d49ffef commit a7f7271
Showing 1 changed file with 37 additions and 3 deletions.
40 changes: 37 additions & 3 deletions flux_dev_tools/server/serialization.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import base64
import io
import json
from dataclasses import is_dataclass
from datetime import date, datetime
from decimal import Decimal
from enum import Enum
from typing import Union, get_args, get_origin
from typing import Any, Union, get_args, get_origin


class FluxJSONEncoder(json.JSONEncoder):
Expand Down Expand Up @@ -50,7 +51,7 @@ def default(self, obj):
if isinstance(obj, dict):
return {key: self.default(value) for key, value in obj.items()}
if hasattr(obj, "__dict__"):
serialized_dict = {}
serialized_dict = {"target_type": repr(type(obj))}
for key, value in obj.__dict__.items():
serialized_dict[key] = json.dumps(value, cls=FluxJSONEncoder)
return serialized_dict
Expand Down Expand Up @@ -118,13 +119,23 @@ def decode(self, s, *args, **kwargs):
return self._convert_object(obj, self.target_type)

def _convert_object(self, obj, target_type):
passed_in_target_type = obj.pop("target_type", None) if isinstance(obj, dict) else None
if get_origin(target_type) is Union:
for union_type in get_args(target_type):
if repr(union_type) == passed_in_target_type:
target_type = union_type
break
if get_origin(target_type) is Union:
target_type = type(obj)
if get_origin(target_type) is list:
target_type = get_args(target_type)[0]
if get_origin(target_type) is tuple and obj is not None:
return tuple(self._convert_object(item, get_args(target_type)[i]) for i, item in enumerate(obj))
if isinstance(obj, list):
return [self._convert_object(item, target_type) for item in obj]

if target_type is not None:
if obj is None or type(obj) is get_origin(target_type) or type(obj) is target_type:
if obj is None or type(obj) is target_type:
return obj
elif target_type is Decimal:
return Decimal(obj)
Expand All @@ -144,6 +155,18 @@ def _convert_object(self, obj, target_type):
return int(obj)
elif target_type is str:
return str(obj)
elif target_type is dict or get_origin(target_type) is dict:
deserialized_dict = {}
for key, value in obj.items():
value_target_type = None
if (
hasattr(target_type, "__args__")
and len(target_type.__args__) == 2
and target_type.__args__[1] is not Any
):
value_target_type = target_type.__args__[1]
deserialized_dict[key] = self._convert_object(value, target_type=value_target_type)
return deserialized_dict
else:
if not obj:
return target_type()
Expand All @@ -154,11 +177,22 @@ def _convert_object(self, obj, target_type):
if self.is_optional_type(expected_type):
expected_type = expected_type.__args__[0]
deserialized_dict[key] = json.loads(value, cls=FluxJSONDecoder, target_type=expected_type)
if is_dataclass(target_type):
return target_type(**deserialized_dict)
instance = target_type()
for key, value in target_type.__annotations__.items():
setattr(instance, key, deserialized_dict.get(key, None))
return instance

if isinstance(obj, str):
try:
return datetime.fromisoformat(obj)
except ValueError:
pass

if isinstance(obj, dict):
return {key: self._convert_object(value, target_type) for key, value in obj.items()}

return obj

def is_optional_type(self, typ):
Expand Down

0 comments on commit a7f7271

Please sign in to comment.