Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for InitVar #495

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions dataclasses_json/core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import inspect
import json
import sys
import warnings
Expand Down Expand Up @@ -145,7 +146,7 @@ def _decode_dataclass(cls, kvs, infer_missing):
return kvs
overrides = _user_overrides_or_exts(cls)
kvs = {} if kvs is None and infer_missing else kvs
field_names = [field.name for field in fields(cls)]
field_names = set(cls.__dataclass_fields__.keys())
decode_names = _decode_letter_case_overrides(field_names, overrides)
kvs = {decode_names.get(k, k): v for k, v in kvs.items()}
missing_fields = {field for field in fields(cls) if field.name not in kvs}
Expand All @@ -163,18 +164,20 @@ def _decode_dataclass(cls, kvs, infer_missing):

init_kwargs = {}
types = get_type_hints(cls)
for field in fields(cls):
constructor_args = set(inspect.signature(cls).parameters.keys())

for field_name in field_names:
# The field should be skipped from being added
# to init_kwargs as it's not intended as a constructor argument.
if not field.init:
if field_name not in constructor_args:
continue

field_value = kvs[field.name]
field_type = types[field.name]
field_value = kvs[field_name]
field_type = types[field_name]
if field_value is None:
if not _is_optional(field_type):
warning = (
f"value of non-optional type {field.name} detected "
f"value of non-optional type {field_name} detected "
f"when decoding {cls.__name__}"
)
if infer_missing:
Expand All @@ -188,7 +191,7 @@ def _decode_dataclass(cls, kvs, infer_missing):
warnings.warn(
f"'NoneType' object {warning}.", RuntimeWarning
)
init_kwargs[field.name] = field_value
init_kwargs[field_name] = field_value
continue

while True:
Expand All @@ -197,13 +200,13 @@ def _decode_dataclass(cls, kvs, infer_missing):

field_type = field_type.__supertype__

if (field.name in overrides
and overrides[field.name].decoder is not None):
if (field_name in overrides
and overrides[field_name].decoder is not None):
# FIXME hack
if field_type is type(field_value):
init_kwargs[field.name] = field_value
init_kwargs[field_name] = field_value
else:
init_kwargs[field.name] = overrides[field.name].decoder(
init_kwargs[field_name] = overrides[field_name].decoder(
field_value)
elif is_dataclass(field_type):
# FIXME this is a band-aid to deal with the value already being
Expand All @@ -215,13 +218,13 @@ def _decode_dataclass(cls, kvs, infer_missing):
else:
value = _decode_dataclass(field_type, field_value,
infer_missing)
init_kwargs[field.name] = value
init_kwargs[field_name] = value
elif _is_supported_generic(field_type) and field_type != str:
init_kwargs[field.name] = _decode_generic(field_type,
init_kwargs[field_name] = _decode_generic(field_type,
field_value,
infer_missing)
else:
init_kwargs[field.name] = _support_extended_types(field_type,
init_kwargs[field_name] = _support_extended_types(field_type,
field_value)

return cls(**init_kwargs)
Expand Down
26 changes: 26 additions & 0 deletions tests/test_init_var.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from dataclasses import InitVar, dataclass
from typing import Optional

import pytest

from dataclasses_json import DataClassJsonMixin


@dataclass
class A(DataClassJsonMixin):
a_init: InitVar[int]
_a: Optional[int] = None

def __post_init__(self, a_init: int):
self._a = a_init


class TestEncoder:
def test_init_var(self):
assert A(a_init=1).to_dict() == {'_a': 1}


class TestDecoder:
def test_init_var(self):
result = A.from_dict({'a_init': 1})
assert result._a == 1