From dcb9cee87b56db87f5eb1a5761823ed18bc898ef Mon Sep 17 00:00:00 2001 From: Omry Yadan Date: Sat, 16 Jan 2021 10:56:50 -0800 Subject: [PATCH] Tentative support for TypedDict --- news/473.feature | 1 + omegaconf/_utils.py | 7 +++-- tests/structured_conf/data/attr_classes.py | 21 +++++++++++++-- tests/structured_conf/data/dataclasses.py | 21 +++++++++++++-- .../structured_conf/test_structured_config.py | 26 ++++++++++++------- 5 files changed, 61 insertions(+), 15 deletions(-) create mode 100644 news/473.feature diff --git a/news/473.feature b/news/473.feature new file mode 100644 index 000000000..09a46d524 --- /dev/null +++ b/news/473.feature @@ -0,0 +1 @@ +Add tentative support for typing.TypedDict diff --git a/omegaconf/_utils.py b/omegaconf/_utils.py index 16b0cbf8d..e5b3adea7 100644 --- a/omegaconf/_utils.py +++ b/omegaconf/_utils.py @@ -422,8 +422,11 @@ def is_dict_annotation(type_: Any) -> bool: origin = getattr(type_, "__origin__", None) if sys.version_info < (3, 7, 0): return origin is Dict or type_ is Dict # pragma: no cover - else: - return origin is dict # pragma: no cover + else: # pragma: no cover + # type_dict is a bit hard to detect. + # this support is tentative, if it eventually causes issues in other areas it may be dropped. + typed_dict = hasattr(type_, "__base__") and type_.__base__ == dict + return origin is dict or typed_dict def is_list_annotation(type_: Any) -> bool: diff --git a/tests/structured_conf/data/attr_classes.py b/tests/structured_conf/data/attr_classes.py index e2f286ba2..29f710e05 100644 --- a/tests/structured_conf/data/attr_classes.py +++ b/tests/structured_conf/data/attr_classes.py @@ -1,3 +1,4 @@ +import sys from typing import Any, Dict, List, Optional, Tuple, Union import attr @@ -6,6 +7,9 @@ from omegaconf import II, MISSING, SI from tests import Color +if sys.version_info >= (3, 8): # pragma: no cover + from typing import TypedDict + # attr is a dependency of pytest which means it's always available when testing with pytest. pytest.importorskip("attr") @@ -20,6 +24,12 @@ def __eq__(self, other: Any) -> Any: return False +if sys.version_info >= (3, 8): # pragma: no cover + + class TypedDictSubclass(TypedDict): + foo: str + + @attr.s(auto_attribs=True) class StructuredWithInvalidField: bar: NotStructuredConfig = NotStructuredConfig() @@ -310,15 +320,22 @@ class ContainsFrozen: @attr.s(auto_attribs=True) -class WithTypedList: +class WithListField: list: List[int] = [1, 2, 3] @attr.s(auto_attribs=True) -class WithTypedDict: +class WithDictField: dict: Dict[str, int] = {"foo": 10, "bar": 20} +if sys.version_info >= (3, 8): # pragma: no cover + + @attr.s(auto_attribs=True) + class WithTypedDictField: + dict: TypedDictSubclass + + @attr.s(auto_attribs=True) class ErrorDictObjectKey: # invalid dict key, must be str diff --git a/tests/structured_conf/data/dataclasses.py b/tests/structured_conf/data/dataclasses.py index 867309164..0e8f99b64 100644 --- a/tests/structured_conf/data/dataclasses.py +++ b/tests/structured_conf/data/dataclasses.py @@ -1,4 +1,5 @@ import dataclasses +import sys from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Tuple, Union @@ -7,6 +8,9 @@ from omegaconf import II, MISSING, SI from tests import Color +if sys.version_info >= (3, 8): # pragma: no cover + from typing import TypedDict + # skip test if dataclasses are not available pytest.importorskip("dataclasses") @@ -21,6 +25,12 @@ def __eq__(self, other: Any) -> Any: return False +if sys.version_info >= (3, 8): # pragma: no cover + + class TypedDictSubclass(TypedDict): + foo: str + + @dataclass class StructuredWithInvalidField: bar: NotStructuredConfig = NotStructuredConfig() @@ -313,15 +323,22 @@ class ContainsFrozen: @dataclass -class WithTypedList: +class WithListField: list: List[int] = field(default_factory=lambda: [1, 2, 3]) @dataclass -class WithTypedDict: +class WithDictField: dict: Dict[str, int] = field(default_factory=lambda: {"foo": 10, "bar": 20}) +if sys.version_info >= (3, 8): # pragma: no cover + + @dataclass + class WithTypedDictField: + dict: TypedDictSubclass + + @dataclass class ErrorDictObjectKey: # invalid dict key, must be str diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index 25342ceaa..a6e6b1759 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -1,3 +1,4 @@ +import sys from enum import Enum from importlib import import_module from typing import Any, Dict, List, Optional, Union @@ -433,9 +434,9 @@ def test_optional(self, class_type: str, tested_type: str) -> None: conf.with_default = None assert conf.with_default is None - def test_typed_list(self, class_type: str) -> None: + def test_list_field(self, class_type: str) -> None: module: Any = import_module(class_type) - input_ = module.WithTypedList + input_ = module.WithListField conf = OmegaConf.structured(input_) with pytest.raises(ValidationError): conf.list[0] = "fail" @@ -447,9 +448,9 @@ def test_typed_list(self, class_type: str) -> None: cfg2 = OmegaConf.create({"list": ["fail"]}) OmegaConf.merge(conf, cfg2) - def test_typed_dict(self, class_type: str) -> None: + def test_dict_field(self, class_type: str) -> None: module: Any = import_module(class_type) - input_ = module.WithTypedDict + input_ = module.WithDictField conf = OmegaConf.structured(input_) with pytest.raises(ValidationError): conf.dict["foo"] = "fail" @@ -457,10 +458,17 @@ def test_typed_dict(self, class_type: str) -> None: with pytest.raises(ValidationError): OmegaConf.merge(conf, OmegaConf.create({"dict": {"foo": "fail"}})) + @pytest.mark.skipif(sys.version_info < (3, 8), reason="requires Python 3.8 or newer") # type: ignore + def test_typed_dict_field(self, class_type: str) -> None: + module: Any = import_module(class_type) + input_ = module.WithTypedDictField + conf = OmegaConf.structured(input_(dict={"foo": "bar"})) + assert conf.dict["foo"] == "bar" + def test_merged_type1(self, class_type: str) -> None: # Test that the merged type is that of the last merged config module: Any = import_module(class_type) - input_ = module.WithTypedDict + input_ = module.WithDictField conf = OmegaConf.structured(input_) res = OmegaConf.merge(OmegaConf.create(), conf) assert OmegaConf.get_type(res) == input_ @@ -468,7 +476,7 @@ def test_merged_type1(self, class_type: str) -> None: def test_merged_type2(self, class_type: str) -> None: # Test that the merged type is that of the last merged config module: Any = import_module(class_type) - input_ = module.WithTypedDict + input_ = module.WithDictField conf = OmegaConf.structured(input_) res = OmegaConf.merge(conf, {"dict": {"foo": 99}}) assert OmegaConf.get_type(res) == input_ @@ -557,19 +565,19 @@ def test_merge_dict_with_correct_type(self, class_type: str) -> None: res = OmegaConf.merge(cfg, {"dict": {"foo": user}}) assert res.dict == {"foo": user} - def test_typed_dict_key_error(self, class_type: str) -> None: + def test_dict_field_key_type_error(self, class_type: str) -> None: module: Any = import_module(class_type) input_ = module.ErrorDictObjectKey with pytest.raises(KeyValidationError): OmegaConf.structured(input_) - def test_typed_dict_value_error(self, class_type: str) -> None: + def test_dict_field_value_type_error(self, class_type: str) -> None: module: Any = import_module(class_type) input_ = module.ErrorDictUnsupportedValue with pytest.raises(ValidationError): OmegaConf.structured(input_) - def test_typed_list_value_error(self, class_type: str) -> None: + def test_list_field_value_type_error(self, class_type: str) -> None: module: Any = import_module(class_type) input_ = module.ErrorListUnsupportedValue with pytest.raises(ValidationError):