From 240ae7124877f9c7088dcb0084820a50c3e1a5f7 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Thu, 2 Dec 2021 00:00:10 -0600 Subject: [PATCH] get_dataclass_data: branch on dataclass vs dataclass instance --- news/831.bugfix | 1 + omegaconf/_utils.py | 16 ++++++--- tests/structured_conf/data/attr_classes.py | 10 ++++++ tests/structured_conf/data/dataclasses.py | 10 ++++++ .../structured_conf/test_structured_config.py | 36 ++++++++++++++++++- 5 files changed, 67 insertions(+), 6 deletions(-) create mode 100644 news/831.bugfix diff --git a/news/831.bugfix b/news/831.bugfix new file mode 100644 index 000000000..e01b68f64 --- /dev/null +++ b/news/831.bugfix @@ -0,0 +1 @@ +Fix bugs related to creation of structured configs from dataclasses having fields with a default_factory diff --git a/omegaconf/_utils.py b/omegaconf/_utils.py index 5450c7b7d..9a38c1bc0 100644 --- a/omegaconf/_utils.py +++ b/omegaconf/_utils.py @@ -336,6 +336,7 @@ def get_dataclass_data( flags = {"allow_objects": allow_objects} if allow_objects is not None else {} d = {} + is_type = isinstance(obj, type) obj_type = get_type_of(obj) dummy_parent = OmegaConf.create({}, flags=flags) dummy_parent._metadata.object_type = obj_type @@ -344,13 +345,18 @@ def get_dataclass_data( name = field.name is_optional, type_ = _resolve_optional(resolved_hints[field.name]) type_ = _resolve_forward(type_, obj.__module__) + has_default = field.default != dataclasses.MISSING + has_default_factory = field.default_factory != dataclasses.MISSING # type: ignore - value = getattr(obj, name, MISSING) - if value in (MISSING, dataclasses.MISSING): - if field.default_factory == dataclasses.MISSING: # type: ignore - value = MISSING - else: + if not is_type: + value = getattr(obj, name) + else: + if has_default: + value = field.default + elif has_default_factory: value = field.default_factory() # type: ignore + else: + value = MISSING if _is_union(type_): e = ConfigValueError( diff --git a/tests/structured_conf/data/attr_classes.py b/tests/structured_conf/data/attr_classes.py index 50e6cf181..b42016db4 100644 --- a/tests/structured_conf/data/attr_classes.py +++ b/tests/structured_conf/data/attr_classes.py @@ -608,6 +608,16 @@ class ChildContainers(ParentContainers): list1: List[int] = [1, 2, 3] dict: Dict[str, Any] = {"a": 5, "b": 6} + @attr.s(auto_attribs=True) + class ParentNoDefaultFactory: + no_default_to_list: Any + int_to_list: Any = 1 + + @attr.s(auto_attribs=True) + class ChildWithDefaultFactory(ParentNoDefaultFactory): + no_default_to_list: Any = ["hi"] + int_to_list: Any = ["hi"] + @attr.s(auto_attribs=True) class HasInitFalseFields: diff --git a/tests/structured_conf/data/dataclasses.py b/tests/structured_conf/data/dataclasses.py index 02b5c8d4a..987f224a7 100644 --- a/tests/structured_conf/data/dataclasses.py +++ b/tests/structured_conf/data/dataclasses.py @@ -629,6 +629,16 @@ class ChildContainers(ParentContainers): list1: List[int] = field(default_factory=lambda: [1, 2, 3]) dict: Dict[str, Any] = field(default_factory=lambda: {"a": 5, "b": 6}) + @dataclass + class ParentNoDefaultFactory: + no_default_to_list: Any + int_to_list: Any = 1 + + @dataclass + class ChildWithDefaultFactory(ParentNoDefaultFactory): + no_default_to_list: Any = field(default_factory=lambda: ["hi"]) + int_to_list: Any = field(default_factory=lambda: ["hi"]) + @dataclass class HasInitFalseFields: diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index 493003c46..c6631ea35 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -1,7 +1,7 @@ import inspect import sys from importlib import import_module -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional from pytest import fixture, mark, param, raises @@ -211,6 +211,13 @@ def validate(cfg: DictConfig) -> None: conf2 = OmegaConf.structured(module.ConfigWithList()) validate(conf2) + def test_config_with_list_nondefault_values(self, module: Any) -> None: + conf1 = OmegaConf.structured(module.ConfigWithList(list1=[4, 5, 6])) + assert conf1.list1 == [4, 5, 6] + + conf2 = OmegaConf.structured(module.ConfigWithList(list1=MISSING)) + assert OmegaConf.is_missing(conf2, "list1") + def test_assignment_to_nested_structured_config(self, module: Any) -> None: conf = OmegaConf.structured(module.NestedConfig) with raises(ValidationError): @@ -236,6 +243,13 @@ def validate(cfg: DictConfig) -> None: conf2 = OmegaConf.structured(module.ConfigWithDict()) validate(conf2) + def test_config_with_dict_nondefault_values(self, module: Any) -> None: + conf1 = OmegaConf.structured(module.ConfigWithDict(dict1={"baz": "qux"})) + assert conf1.dict1 == {"baz": "qux"} + + conf2 = OmegaConf.structured(module.ConfigWithDict(dict1=MISSING)) + assert OmegaConf.is_missing(conf2, "dict1") + def test_structured_config_struct_behavior(self, module: Any) -> None: def validate(cfg: DictConfig) -> None: assert not OmegaConf.is_struct(cfg) @@ -1231,6 +1245,26 @@ def test_container_inheritance(self, module: Any) -> None: assert OmegaConf.is_missing(parent, "dict") assert child.dict == {"a": 5, "b": 6} + @mark.parametrize( + "create_fn", + [ + param(lambda cls: OmegaConf.structured(cls), id="create_from_class"), + param(lambda cls: OmegaConf.structured(cls()), id="create_from_instance"), + ], + ) + def test_subclass_using_default_factory( + self, module: Any, create_fn: Callable[[Any], DictConfig] + ) -> None: + """ + When a structured config field has a default and a subclass defines a + default_factory for the same field, ensure that the DictConfig created + from the subclass uses the subclass' default_factory (not the parent + class' default). + """ + cfg = create_fn(module.StructuredSubclass.ChildWithDefaultFactory) + assert cfg.no_default_to_list == ["hi"] + assert cfg.int_to_list == ["hi"] + class TestNestedContainers: @mark.parametrize(