Skip to content

Commit

Permalink
Tentative support for TypedDict
Browse files Browse the repository at this point in the history
  • Loading branch information
omry committed Jan 16, 2021
1 parent 32f2ab4 commit dcb9cee
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 15 deletions.
1 change: 1 addition & 0 deletions news/473.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add tentative support for typing.TypedDict
7 changes: 5 additions & 2 deletions omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,11 @@ def is_dict_annotation(type_: Any) -> bool:
origin = getattr(type_, "__origin__", None)
if sys.version_info < (3, 7, 0):
return origin is Dict or type_ is Dict # pragma: no cover
else:
return origin is dict # pragma: no cover
else: # pragma: no cover
# type_dict is a bit hard to detect.
# this support is tentative, if it eventually causes issues in other areas it may be dropped.
typed_dict = hasattr(type_, "__base__") and type_.__base__ == dict
return origin is dict or typed_dict


def is_list_annotation(type_: Any) -> bool:
Expand Down
21 changes: 19 additions & 2 deletions tests/structured_conf/data/attr_classes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from typing import Any, Dict, List, Optional, Tuple, Union

import attr
Expand All @@ -6,6 +7,9 @@
from omegaconf import II, MISSING, SI
from tests import Color

if sys.version_info >= (3, 8): # pragma: no cover
from typing import TypedDict

# attr is a dependency of pytest which means it's always available when testing with pytest.
pytest.importorskip("attr")

Expand All @@ -20,6 +24,12 @@ def __eq__(self, other: Any) -> Any:
return False


if sys.version_info >= (3, 8): # pragma: no cover

class TypedDictSubclass(TypedDict):
foo: str


@attr.s(auto_attribs=True)
class StructuredWithInvalidField:
bar: NotStructuredConfig = NotStructuredConfig()
Expand Down Expand Up @@ -310,15 +320,22 @@ class ContainsFrozen:


@attr.s(auto_attribs=True)
class WithTypedList:
class WithListField:
list: List[int] = [1, 2, 3]


@attr.s(auto_attribs=True)
class WithTypedDict:
class WithDictField:
dict: Dict[str, int] = {"foo": 10, "bar": 20}


if sys.version_info >= (3, 8): # pragma: no cover

@attr.s(auto_attribs=True)
class WithTypedDictField:
dict: TypedDictSubclass


@attr.s(auto_attribs=True)
class ErrorDictObjectKey:
# invalid dict key, must be str
Expand Down
21 changes: 19 additions & 2 deletions tests/structured_conf/data/dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import dataclasses
import sys
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union

Expand All @@ -7,6 +8,9 @@
from omegaconf import II, MISSING, SI
from tests import Color

if sys.version_info >= (3, 8): # pragma: no cover
from typing import TypedDict

# skip test if dataclasses are not available
pytest.importorskip("dataclasses")

Expand All @@ -21,6 +25,12 @@ def __eq__(self, other: Any) -> Any:
return False


if sys.version_info >= (3, 8): # pragma: no cover

class TypedDictSubclass(TypedDict):
foo: str


@dataclass
class StructuredWithInvalidField:
bar: NotStructuredConfig = NotStructuredConfig()
Expand Down Expand Up @@ -313,15 +323,22 @@ class ContainsFrozen:


@dataclass
class WithTypedList:
class WithListField:
list: List[int] = field(default_factory=lambda: [1, 2, 3])


@dataclass
class WithTypedDict:
class WithDictField:
dict: Dict[str, int] = field(default_factory=lambda: {"foo": 10, "bar": 20})


if sys.version_info >= (3, 8): # pragma: no cover

@dataclass
class WithTypedDictField:
dict: TypedDictSubclass


@dataclass
class ErrorDictObjectKey:
# invalid dict key, must be str
Expand Down
26 changes: 17 additions & 9 deletions tests/structured_conf/test_structured_config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from enum import Enum
from importlib import import_module
from typing import Any, Dict, List, Optional, Union
Expand Down Expand Up @@ -433,9 +434,9 @@ def test_optional(self, class_type: str, tested_type: str) -> None:
conf.with_default = None
assert conf.with_default is None

def test_typed_list(self, class_type: str) -> None:
def test_list_field(self, class_type: str) -> None:
module: Any = import_module(class_type)
input_ = module.WithTypedList
input_ = module.WithListField
conf = OmegaConf.structured(input_)
with pytest.raises(ValidationError):
conf.list[0] = "fail"
Expand All @@ -447,28 +448,35 @@ def test_typed_list(self, class_type: str) -> None:
cfg2 = OmegaConf.create({"list": ["fail"]})
OmegaConf.merge(conf, cfg2)

def test_typed_dict(self, class_type: str) -> None:
def test_dict_field(self, class_type: str) -> None:
module: Any = import_module(class_type)
input_ = module.WithTypedDict
input_ = module.WithDictField
conf = OmegaConf.structured(input_)
with pytest.raises(ValidationError):
conf.dict["foo"] = "fail"

with pytest.raises(ValidationError):
OmegaConf.merge(conf, OmegaConf.create({"dict": {"foo": "fail"}}))

@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires Python 3.8 or newer") # type: ignore
def test_typed_dict_field(self, class_type: str) -> None:
module: Any = import_module(class_type)
input_ = module.WithTypedDictField
conf = OmegaConf.structured(input_(dict={"foo": "bar"}))
assert conf.dict["foo"] == "bar"

def test_merged_type1(self, class_type: str) -> None:
# Test that the merged type is that of the last merged config
module: Any = import_module(class_type)
input_ = module.WithTypedDict
input_ = module.WithDictField
conf = OmegaConf.structured(input_)
res = OmegaConf.merge(OmegaConf.create(), conf)
assert OmegaConf.get_type(res) == input_

def test_merged_type2(self, class_type: str) -> None:
# Test that the merged type is that of the last merged config
module: Any = import_module(class_type)
input_ = module.WithTypedDict
input_ = module.WithDictField
conf = OmegaConf.structured(input_)
res = OmegaConf.merge(conf, {"dict": {"foo": 99}})
assert OmegaConf.get_type(res) == input_
Expand Down Expand Up @@ -557,19 +565,19 @@ def test_merge_dict_with_correct_type(self, class_type: str) -> None:
res = OmegaConf.merge(cfg, {"dict": {"foo": user}})
assert res.dict == {"foo": user}

def test_typed_dict_key_error(self, class_type: str) -> None:
def test_dict_field_key_type_error(self, class_type: str) -> None:
module: Any = import_module(class_type)
input_ = module.ErrorDictObjectKey
with pytest.raises(KeyValidationError):
OmegaConf.structured(input_)

def test_typed_dict_value_error(self, class_type: str) -> None:
def test_dict_field_value_type_error(self, class_type: str) -> None:
module: Any = import_module(class_type)
input_ = module.ErrorDictUnsupportedValue
with pytest.raises(ValidationError):
OmegaConf.structured(input_)

def test_typed_list_value_error(self, class_type: str) -> None:
def test_list_field_value_type_error(self, class_type: str) -> None:
module: Any = import_module(class_type)
input_ = module.ErrorListUnsupportedValue
with pytest.raises(ValidationError):
Expand Down

0 comments on commit dcb9cee

Please sign in to comment.