From 15d9abb0570536b13c21b20dcd2d75faeef2d20c Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Tue, 15 Dec 2020 02:08:27 -0600 Subject: [PATCH 01/85] to_container: instantiate_structured_configs flag --- omegaconf/basecontainer.py | 13 +++ omegaconf/omegaconf.py | 2 + tests/test_base_config.py | 174 +++++++++++++++++++++++++++++++++++++ 3 files changed, 189 insertions(+) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 0a46cb5c8..3d9689eda 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -193,10 +193,16 @@ def _to_content( resolve: bool, enum_to_str: bool = False, exclude_structured_configs: bool = False, + instantiate_structured_configs: bool = False, ) -> Union[None, Any, str, Dict[DictKeyType, Any], List[Any]]: from .dictconfig import DictConfig from .listconfig import ListConfig + if exclude_structured_configs and instantiate_structured_configs: + raise ValueError( + "Cannot both exclude and and instantiate structured configs" + ) + def convert(val: Node) -> Any: value = val._value() if enum_to_str and isinstance(value, Enum): @@ -233,9 +239,15 @@ def convert(val: Node) -> Any: resolve=resolve, enum_to_str=enum_to_str, exclude_structured_configs=exclude_structured_configs, + instantiate_structured_configs=instantiate_structured_configs, ) else: retdict[key] = convert(node) + if ( + conf._metadata.object_type is not None + and instantiate_structured_configs + ): + return conf._metadata.object_type(**retdict) return retdict elif isinstance(conf, ListConfig): retlist: List[Any] = [] @@ -253,6 +265,7 @@ def convert(val: Node) -> Any: resolve=resolve, enum_to_str=enum_to_str, exclude_structured_configs=exclude_structured_configs, + instantiate_structured_configs=instantiate_structured_configs, ) retlist.append(item) else: diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index ee5b7c244..4ca21fbb8 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -580,6 +580,7 @@ def to_container( resolve: bool = False, enum_to_str: bool = False, exclude_structured_configs: bool = False, + instantiate_structured_configs: bool = False, ) -> Union[Dict[DictKeyType, Any], List[Any], None, str]: """ Resursively converts an OmegaConf config to a primitive container (dict or list). @@ -600,6 +601,7 @@ def to_container( resolve=resolve, enum_to_str=enum_to_str, exclude_structured_configs=exclude_structured_configs, + instantiate_structured_configs=instantiate_structured_configs, ) @staticmethod diff --git a/tests/test_base_config.py b/tests/test_base_config.py index 482b5e0de..be9ef6031 100644 --- a/tests/test_base_config.py +++ b/tests/test_base_config.py @@ -80,6 +80,180 @@ def test_replace_value_node_type_with_another( assert c[key] == value._value() +@pytest.mark.parametrize( + "input_", + [ + pytest.param([1, 2, 3], id="list"), + pytest.param([1, 2, {"a": 3}], id="dict_in_list"), + pytest.param([1, 2, [10, 20]], id="list_in_list"), + pytest.param({"b": {"b": 10}}, id="dict_in_dict"), + pytest.param({"b": [False, 1, "2", 3.0, Color.RED]}, id="list_in_dict"), + pytest.param({"b": DictConfig(content=None)}, id="none_dictconfig"), + pytest.param({"b": ListConfig(content=None)}, id="none_listconfig"), + pytest.param({"b": DictConfig(content="???")}, id="missing_dictconfig"), + pytest.param({"b": ListConfig(content="???")}, id="missing_listconfig"), + ], +) +def test_to_container_returns_primitives(input_: Any) -> None: + def assert_container_with_primitives(item: Any) -> None: + if isinstance(item, list): + for v in item: + assert_container_with_primitives(v) + elif isinstance(item, dict): + for _k, v in item.items(): + assert_container_with_primitives(v) + else: + assert isinstance(item, (int, float, str, bool, type(None), Enum)) + + c = OmegaConf.create(input_) + res = OmegaConf.to_container(c, resolve=True) + assert_container_with_primitives(res) + + +@pytest.mark.parametrize( + "cfg,ex_false,ex_true", + [ + pytest.param( + {"user": User(age=7, name="Bond")}, + {"user": {"name": "Bond", "age": 7}}, + {"user": User(age=7, name="Bond")}, + ), + pytest.param( + [1, User(age=7, name="Bond")], + [1, {"name": "Bond", "age": 7}], + [1, User(age=7, name="Bond")], + ), + pytest.param( + {"users": [User(age=1, name="a"), User(age=2, name="b")]}, + {"users": [{"age": 1, "name": "a"}, {"age": 2, "name": "b"}]}, + {"users": [User(age=1, name="a"), User(age=2, name="b")]}, + ), + ], +) +def test_exclude_structured_configs(cfg: Any, ex_false: Any, ex_true: Any) -> None: + cfg = OmegaConf.create(cfg) + ret1 = OmegaConf.to_container(cfg, exclude_structured_configs=False) + assert ret1 == ex_false + + ret1 = OmegaConf.to_container(cfg, instantiate_structured_configs=True) + assert ret1 == ex_true + + ret1 = OmegaConf.to_container(cfg, exclude_structured_configs=True) + assert ret1 == ex_true + + +@pytest.mark.parametrize( + "src, expected, expected_with_resolve", + [ + pytest.param([], None, None, id="empty_list"), + pytest.param([1, 2, 3], None, None, id="list"), + pytest.param([None], None, None, id="list_with_none"), + pytest.param([1, "${0}", 3], None, [1, 1, 3], id="list_with_inter"), + pytest.param({}, None, None, id="empty_dict"), + pytest.param({"foo": "bar"}, None, None, id="dict"), + pytest.param( + {"foo": "${bar}", "bar": "zonk"}, + None, + {"foo": "zonk", "bar": "zonk"}, + id="dict_with_inter", + ), + pytest.param({"foo": None}, None, None, id="dict_with_none"), + pytest.param({"foo": "???"}, None, None, id="dict_missing_value"), + pytest.param({"foo": None}, None, None, id="dict_none_value"), + # containers + pytest.param( + {"foo": DictConfig(is_optional=True, content=None)}, + {"foo": None}, + None, + id="dict_none_dictconfig", + ), + pytest.param( + {"foo": DictConfig(content="???")}, + {"foo": "???"}, + None, + id="dict_missing_dictconfig", + ), + pytest.param( + {"foo": DictConfig(content="${bar}"), "bar": 10}, + {"foo": "${bar}", "bar": 10}, + {"foo": 10, "bar": 10}, + id="dict_inter_dictconfig", + ), + pytest.param( + {"foo": ListConfig(content="???")}, + {"foo": "???"}, + None, + id="dict_missing_listconfig", + ), + pytest.param( + {"foo": ListConfig(is_optional=True, content=None)}, + {"foo": None}, + None, + id="dict_none_listconfig", + ), + pytest.param( + {"foo": ListConfig(content="${bar}"), "bar": 10}, + {"foo": "${bar}", "bar": 10}, + {"foo": 10, "bar": 10}, + id="dict_inter_listconfig", + ), + ], +) +def test_to_container(src: Any, expected: Any, expected_with_resolve: Any) -> None: + if expected is None: + expected = src + if expected_with_resolve is None: + expected_with_resolve = expected + cfg = OmegaConf.create(src) + container = OmegaConf.to_container(cfg) + assert container == expected + container = OmegaConf.to_container(cfg, instantiate_structured_configs=True) + assert container == expected + container = OmegaConf.to_container(cfg, resolve=True) + assert container == expected_with_resolve + + +def test_to_container_invalid_input() -> None: + with pytest.raises( + ValueError, + match=re.escape("Input cfg is not an OmegaConf config object (dict)"), + ): + OmegaConf.to_container({}) + + +def test_to_container_options_mutually_exclusive() -> None: + with raises(ValueError): + cfg = OmegaConf.create() + OmegaConf.to_container( + cfg, exclude_structured_configs=True, instantiate_structured_configs=True + ) + + +def test_string_interpolation_with_readonly_parent() -> None: + cfg = OmegaConf.create({"a": 10, "b": {"c": "hello_${a}"}}) + OmegaConf.set_readonly(cfg, True) + assert OmegaConf.to_container(cfg, resolve=True) == { + "a": 10, + "b": {"c": "hello_10"}, + } + + +@pytest.mark.parametrize( + "src,expected", + [ + pytest.param(DictConfig(content="${bar}"), "${bar}", id="DictConfig"), + pytest.param( + OmegaConf.create({"foo": DictConfig(content="${bar}")}), + {"foo": "${bar}"}, + id="nested_DictConfig", + ), + ], +) +def test_to_container_missing_inter_no_resolve(src: Any, expected: Any) -> None: + res = OmegaConf.to_container(src, resolve=False) + assert res == expected + + @pytest.mark.parametrize( "input_, is_empty", [ From c4b042aedc0c5eebcf18cb719f3b5a81d3746d9e Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 16 Dec 2020 14:50:35 -0600 Subject: [PATCH 02/85] add a failing test --- tests/structured_conf/test_structured_config.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index 6c0109b39..cd5297f96 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -1039,6 +1039,11 @@ def test_str2str_with_field(self, class_type: str) -> None: with pytest.raises(KeyValidationError): cfg[Color.RED] = "fail" + data = OmegaConf.to_container(cfg, instantiate_structured_configs) + assert type(data) == module.DictSubclass.Str2StrWithField + assert data.foo == "bar" + assert data["hello"] == "world" + class TestErrors: def test_usr2str(self, class_type: str) -> None: module: Any = import_module(class_type) From fe47c89ab5ae35b2a058af9b477a7d1794de8630 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 16 Dec 2020 19:23:10 -0600 Subject: [PATCH 03/85] fix typo --- tests/structured_conf/test_structured_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index cd5297f96..1c3cf70b3 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -1039,7 +1039,7 @@ def test_str2str_with_field(self, class_type: str) -> None: with pytest.raises(KeyValidationError): cfg[Color.RED] = "fail" - data = OmegaConf.to_container(cfg, instantiate_structured_configs) + data = OmegaConf.to_container(cfg, instantiate_structured_configs=True) assert type(data) == module.DictSubclass.Str2StrWithField assert data.foo == "bar" assert data["hello"] == "world" From c91be999f1f434b16685763763f732b3766b54d8 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Tue, 22 Dec 2020 16:34:20 -0600 Subject: [PATCH 04/85] fix bug --- omegaconf/basecontainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 3d9689eda..659f59583 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -243,11 +243,11 @@ def convert(val: Node) -> Any: ) else: retdict[key] = convert(node) - if ( - conf._metadata.object_type is not None - and instantiate_structured_configs + if instantiate_structured_configs and is_structured_config( + conf._metadata.ref_type ): - return conf._metadata.object_type(**retdict) + assert callable(conf._metadata.ref_type) + retdict = conf._metadata.ref_type(**retdict) return retdict elif isinstance(conf, ListConfig): retlist: List[Any] = [] From 59ee9097037bc5d24e8f130afa783b3e8e73a2e1 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Tue, 22 Dec 2020 19:51:45 -0600 Subject: [PATCH 05/85] another bugfix --- omegaconf/basecontainer.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 659f59583..27bca0495 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -246,8 +246,14 @@ def convert(val: Node) -> Any: if instantiate_structured_configs and is_structured_config( conf._metadata.ref_type ): + # I think that: + # _metadata.ref_type is from the type annotation in the data class, + # _metadata.object_type is the type of the actual object that was + # passed in to omegaconf assert callable(conf._metadata.ref_type) - retdict = conf._metadata.ref_type(**retdict) + assert callable(conf._metadata.object_type) + assert issubclass(conf._metadata.object_type, conf._metadata.ref_type) + retdict = conf._metadata.object_type(**retdict) return retdict elif isinstance(conf, ListConfig): retlist: List[Any] = [] From 14b75be3a2b3c1cb498c8b0b41823ebd91406fbb Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Tue, 22 Dec 2020 20:05:06 -0600 Subject: [PATCH 06/85] solved an issue --- omegaconf/basecontainer.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 27bca0495..70025aa31 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -254,6 +254,14 @@ def convert(val: Node) -> Any: assert callable(conf._metadata.object_type) assert issubclass(conf._metadata.object_type, conf._metadata.ref_type) retdict = conf._metadata.object_type(**retdict) + elif instantiate_structured_configs and is_structured_config( + conf._metadata.object_type + ): + # This is the case where the type annotation is NOT a dataclass, but the + # object passed in IS a dataclass. + assert callable(conf._metadata.object_type) + assert issubclass(conf._metadata.object_type, conf._metadata.ref_type) + retdict = conf._metadata.object_type(**retdict) return retdict elif isinstance(conf, ListConfig): retlist: List[Any] = [] From 3354bad9ee625bf96318ff280ed423549feed521 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Tue, 22 Dec 2020 21:58:50 -0600 Subject: [PATCH 07/85] add type assert --- omegaconf/basecontainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 70025aa31..74b83061f 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -259,6 +259,7 @@ def convert(val: Node) -> Any: ): # This is the case where the type annotation is NOT a dataclass, but the # object passed in IS a dataclass. + assert conf._metadata.ref_type is not None assert callable(conf._metadata.object_type) assert issubclass(conf._metadata.object_type, conf._metadata.ref_type) retdict = conf._metadata.object_type(**retdict) From 5cb06ed99ff7ad1e4992637ddb312aec45d20953 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 23 Dec 2020 00:52:38 -0600 Subject: [PATCH 08/85] updates --- omegaconf/basecontainer.py | 61 +++++++++----- .../structured_conf/test_structured_config.py | 84 ++++++++++++++++++- 2 files changed, 125 insertions(+), 20 deletions(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 74b83061f..3d988ac0c 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -243,26 +243,49 @@ def convert(val: Node) -> Any: ) else: retdict[key] = convert(node) - if instantiate_structured_configs and is_structured_config( - conf._metadata.ref_type - ): - # I think that: - # _metadata.ref_type is from the type annotation in the data class, - # _metadata.object_type is the type of the actual object that was - # passed in to omegaconf - assert callable(conf._metadata.ref_type) - assert callable(conf._metadata.object_type) - assert issubclass(conf._metadata.object_type, conf._metadata.ref_type) - retdict = conf._metadata.object_type(**retdict) - elif instantiate_structured_configs and is_structured_config( - conf._metadata.object_type + + def _instantiate_structured_config_impl(retdict, object_type): + from ._utils import get_structured_config_data + + object_type_field_names = get_structured_config_data(object_type).keys() + if issubclass(object_type, dict): + # Extending dict as a subclass + + retdict_field_items = { + k: v for k, v in retdict.items() if k in object_type_field_names + } + retdict_nonfield_items = { + k: v + for k, v in retdict.items() + if k not in object_type_field_names + } + result = object_type(**retdict_field_items) + result.update(retdict_nonfield_items) + else: + assert set(retdict.keys()) <= set(object_type_field_names) + result = object_type(**retdict) + return result + + ref_type = conf._metadata.ref_type + object_type = conf._metadata.object_type + # I think that: + # ref_type should be either the type annotation for the value (set by e.g. + # a dataclass field type annotation or a typing.Dict type annotation) or, + # if annotation is available, the type of the value. + # object_type (set in dictconfig.DictConfig._set_value_impl) is the type of + # the value, used, possibly a subclass of ref_type. + if is_structured_config(ref_type): + assert is_structured_config(object_type) + if is_structured_config(ref_type) or is_structured_config(object_type): + assert ref_type is not None + assert object_type is not None + if ref_type is not Any: + assert issubclass(object_type, ref_type) + if instantiate_structured_configs and ( + is_structured_config(ref_type) or is_structured_config(object_type) ): - # This is the case where the type annotation is NOT a dataclass, but the - # object passed in IS a dataclass. - assert conf._metadata.ref_type is not None - assert callable(conf._metadata.object_type) - assert issubclass(conf._metadata.object_type, conf._metadata.ref_type) - retdict = conf._metadata.object_type(**retdict) + retdict = _instantiate_structured_config_impl(retdict, object_type) + return retdict elif isinstance(conf, ListConfig): retlist: List[Any] = [] diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index 1c3cf70b3..1130ccb4f 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -864,6 +864,86 @@ def test_create_untyped_list(self, class_type: str) -> None: assert cfg.list == [1, 2] assert cfg.opt_list is None + class TestInstantiateStructuredConfigs: + @pytest.fixture + def module(self, class_type: str) -> Any: + module: Any = import_module(class_type) + return module + + def round_trip_to_container(self, input_data: Any) -> Any: + serialized = OmegaConf.create(input_data) + round_tripped = OmegaConf.to_container( + serialized, instantiate_structured_configs=True + ) + return round_tripped + + def test_basic(self, module: Any) -> None: + user = self.round_trip_to_container(module.User()) + assert isinstance(user, module.User) + assert type(user) is module.User + assert user.name is MISSING + assert user.age is MISSING + + user = self.round_trip_to_container(module.User("Bond", 7)) + assert isinstance(user, module.User) + assert type(user) is module.User + assert user.name == "Bond" + assert user.age == 7 + + def test_nested(self, module: Any) -> None: + data = self.round_trip_to_container({1: module.User()}) + user = data[1] + assert isinstance(user, module.User) + assert type(user) is module.User + assert user.name is MISSING + assert user.age is MISSING + + data = self.round_trip_to_container({1: module.User("Bond", 7)}) + user = data[1] + assert isinstance(user, module.User) + assert type(user) is module.User + assert user.name == "Bond" + assert user.age == 7 + + def test_list(self, module: Any) -> None: + lst = self.round_trip_to_container(module.UserList) + assert isinstance(lst, module.UserList) + assert type(lst) is module.UserList + # assert lst.list is MISSING # fails: lst.list is "???" + assert lst.list == MISSING + + lst = self.round_trip_to_container( + module.UserList([module.User("Bond", 7)]) + ) + assert isinstance(lst, module.UserList) + assert type(lst) is module.UserList + assert len(lst.list) == 1 + user = lst.list[0] + assert isinstance(user, module.User) + assert type(user) is module.User + assert user.name == "Bond" + assert user.age == 7 + + def test_dict(self, module: Any) -> None: + user_dict = self.round_trip_to_container(module.UserDict) + assert isinstance(user_dict, module.UserDict) + assert type(user_dict) is module.UserDict + # assert user_dict.dict is MISSING # fails: dct.dict is "???" + assert user_dict.dict == MISSING + + user_dict = self.round_trip_to_container( + module.UserDict({"user007": module.User("Bond", 7)}) + ) + assert isinstance(user_dict, module.UserDict) + assert type(user_dict) is module.UserDict + assert len(user_dict.dict) == 1 + user = user_dict.dict["user007"] + assert isinstance(user, module.User) + assert type(user) is module.User + assert user.name == "Bond" + assert user.age == 7 + + def validate_frozen_impl(conf: DictConfig) -> None: with pytest.raises(ReadonlyConfigError): @@ -1040,10 +1120,12 @@ def test_str2str_with_field(self, class_type: str) -> None: cfg[Color.RED] = "fail" data = OmegaConf.to_container(cfg, instantiate_structured_configs=True) - assert type(data) == module.DictSubclass.Str2StrWithField + assert isinstance(data, module.DictSubclass.Str2StrWithField) + assert type(data) is module.DictSubclass.Str2StrWithField assert data.foo == "bar" assert data["hello"] == "world" + class TestErrors: def test_usr2str(self, class_type: str) -> None: module: Any = import_module(class_type) From a3676e34d0f1da723748b5b260b5f983360f5ad5 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 23 Dec 2020 10:12:29 -0600 Subject: [PATCH 09/85] test instantiation of subclass of Dict[str, User] --- tests/structured_conf/test_structured_config.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index 1130ccb4f..2007dc9d3 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -1109,6 +1109,11 @@ def test_str2user(self, class_type: str) -> None: # bad key cfg[Color.BLUE] = "nope" + data = OmegaConf.to_container(cfg, instantiate_structured_configs=True) + assert type(data) is module.DictSubclass.Str2User + assert type(data["bond"]) is module.User + assert data["bond"] == module.User("James Bond", 7) + def test_str2str_with_field(self, class_type: str) -> None: module: Any = import_module(class_type) cfg = OmegaConf.structured(module.DictSubclass.Str2StrWithField()) From c5f5e7341f988c2177c52e661433503a12a9d3a0 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 23 Dec 2020 10:36:03 -0600 Subject: [PATCH 10/85] test instantiate_structured_configs-Str2UserWithField --- tests/structured_conf/data/attr_classes.py | 4 +++ tests/structured_conf/data/dataclasses.py | 4 +++ .../structured_conf/test_structured_config.py | 30 +++++++++++++++++-- 3 files changed, 36 insertions(+), 2 deletions(-) diff --git a/tests/structured_conf/data/attr_classes.py b/tests/structured_conf/data/attr_classes.py index d7a8d5626..76cd23cee 100644 --- a/tests/structured_conf/data/attr_classes.py +++ b/tests/structured_conf/data/attr_classes.py @@ -434,6 +434,10 @@ class Str2StrWithField(Dict[str, str]): class Str2IntWithStrField(Dict[str, int]): foo: int = 1 + @attr.s(auto_attribs=True) + class Str2UserWithField(Dict[str, User]): + foo: User = User("Bond", 7) + class Error: @attr.s(auto_attribs=True) class User2Str(Dict[User, str]): diff --git a/tests/structured_conf/data/dataclasses.py b/tests/structured_conf/data/dataclasses.py index 0e8f99b64..3bbfa990d 100644 --- a/tests/structured_conf/data/dataclasses.py +++ b/tests/structured_conf/data/dataclasses.py @@ -453,6 +453,10 @@ class Str2StrWithField(Dict[str, str]): class Str2IntWithStrField(Dict[str, int]): foo: int = 1 + @dataclass + class Str2UserWithField(Dict[str, User]): + foo: User = User("Bond", 7) + class Error: @dataclass class User2Str(Dict[User, str]): diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index 2007dc9d3..e1ed6197c 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -944,7 +944,6 @@ def test_dict(self, module: Any) -> None: assert user.age == 7 - def validate_frozen_impl(conf: DictConfig) -> None: with pytest.raises(ReadonlyConfigError): conf.x = 20 @@ -1114,6 +1113,34 @@ def test_str2user(self, class_type: str) -> None: assert type(data["bond"]) is module.User assert data["bond"] == module.User("James Bond", 7) + def test_str2user_with_field(self, class_type: str) -> None: + module: Any = import_module(class_type) + cfg = OmegaConf.structured(module.DictSubclass.Str2UserWithField()) + + assert cfg.foo.name == "Bond" + assert cfg.foo.age == 7 + assert isinstance(cfg.foo, DictConfig) + + cfg.mp = module.User(name="Moneypenny", age=11) + assert cfg.mp.name == "Moneypenny" + assert cfg.mp.age == 11 + assert isinstance(cfg.mp, DictConfig) + + with pytest.raises(ValidationError): + # bad value + cfg.hello = "world" + + with pytest.raises(KeyValidationError): + # bad key + cfg[Color.BLUE] = "nope" + + data = OmegaConf.to_container(cfg, instantiate_structured_configs=True) + assert type(data) is module.DictSubclass.Str2UserWithField + assert type(data.foo) is module.User + assert data.foo == module.User("Bond", 7) + assert type(data["mp"]) is module.User + assert data["mp"] == module.User("Moneypenny", 11) + def test_str2str_with_field(self, class_type: str) -> None: module: Any = import_module(class_type) cfg = OmegaConf.structured(module.DictSubclass.Str2StrWithField()) @@ -1130,7 +1157,6 @@ def test_str2str_with_field(self, class_type: str) -> None: assert data.foo == "bar" assert data["hello"] == "world" - class TestErrors: def test_usr2str(self, class_type: str) -> None: module: Any = import_module(class_type) From abf8af088b381ea810fdbbaa035937f512c318f4 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Mon, 28 Dec 2020 20:29:08 -0600 Subject: [PATCH 11/85] fix bug with allow_objects flag --- omegaconf/basecontainer.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 3d988ac0c..e30304db4 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -244,10 +244,14 @@ def convert(val: Node) -> Any: else: retdict[key] = convert(node) - def _instantiate_structured_config_impl(retdict, object_type): + def _instantiate_structured_config_impl( + retdict, object_type, allow_objects + ): from ._utils import get_structured_config_data - object_type_field_names = get_structured_config_data(object_type).keys() + object_type_field_names = get_structured_config_data( + object_type, allow_objects=allow_objects + ).keys() if issubclass(object_type, dict): # Extending dict as a subclass @@ -284,7 +288,9 @@ def _instantiate_structured_config_impl(retdict, object_type): if instantiate_structured_configs and ( is_structured_config(ref_type) or is_structured_config(object_type) ): - retdict = _instantiate_structured_config_impl(retdict, object_type) + retdict = _instantiate_structured_config_impl( + retdict, object_type, conf._get_flag("allow_objects") + ) return retdict elif isinstance(conf, ListConfig): From a54cb9895b0902e0014661a39ffb2947d9de6ce4 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Thu, 21 Jan 2021 15:50:27 -0600 Subject: [PATCH 12/85] add comment --- omegaconf/basecontainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index e30304db4..758e22aa7 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -266,6 +266,7 @@ def _instantiate_structured_config_impl( result = object_type(**retdict_field_items) result.update(retdict_nonfield_items) else: + # normal structured config assert set(retdict.keys()) <= set(object_type_field_names) result = object_type(**retdict) return result From 5fdd0194d56ec4cc0cacf7cfe4f7061e35f3b54e Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Thu, 21 Jan 2021 15:57:28 -0600 Subject: [PATCH 13/85] remove dependence on ref_type; use object_type --- omegaconf/basecontainer.py | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 758e22aa7..394c40cae 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -271,24 +271,8 @@ def _instantiate_structured_config_impl( result = object_type(**retdict) return result - ref_type = conf._metadata.ref_type object_type = conf._metadata.object_type - # I think that: - # ref_type should be either the type annotation for the value (set by e.g. - # a dataclass field type annotation or a typing.Dict type annotation) or, - # if annotation is available, the type of the value. - # object_type (set in dictconfig.DictConfig._set_value_impl) is the type of - # the value, used, possibly a subclass of ref_type. - if is_structured_config(ref_type): - assert is_structured_config(object_type) - if is_structured_config(ref_type) or is_structured_config(object_type): - assert ref_type is not None - assert object_type is not None - if ref_type is not Any: - assert issubclass(object_type, ref_type) - if instantiate_structured_configs and ( - is_structured_config(ref_type) or is_structured_config(object_type) - ): + if instantiate_structured_configs and is_structured_config(object_type): retdict = _instantiate_structured_config_impl( retdict, object_type, conf._get_flag("allow_objects") ) From fc49cbdd2f5681d4512abc82f66f5520bfce0203 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Thu, 21 Jan 2021 15:58:11 -0600 Subject: [PATCH 14/85] use keyword args in call to _instantiate_structured_config_impl --- omegaconf/basecontainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 394c40cae..3668d90b1 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -274,7 +274,9 @@ def _instantiate_structured_config_impl( object_type = conf._metadata.object_type if instantiate_structured_configs and is_structured_config(object_type): retdict = _instantiate_structured_config_impl( - retdict, object_type, conf._get_flag("allow_objects") + retdict=retdict, + object_type=object_type, + allow_objects=conf._get_flag("allow_objects"), ) return retdict From 64b3ed736eaca3996a430181249e12d2c3992a2c Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Thu, 21 Jan 2021 16:00:05 -0600 Subject: [PATCH 15/85] refactor for clearer control flow --- omegaconf/basecontainer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 3668d90b1..9f98a33b3 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -273,13 +273,14 @@ def _instantiate_structured_config_impl( object_type = conf._metadata.object_type if instantiate_structured_configs and is_structured_config(object_type): - retdict = _instantiate_structured_config_impl( + retstruct = _instantiate_structured_config_impl( retdict=retdict, object_type=object_type, allow_objects=conf._get_flag("allow_objects"), ) - - return retdict + return retstruct + else: + return retdict elif isinstance(conf, ListConfig): retlist: List[Any] = [] for index in range(len(conf)): From 41588c1b4245ec1897d4b485e08bfba578229ad1 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Sat, 23 Jan 2021 18:11:26 -0600 Subject: [PATCH 16/85] move method _instantiate_structured_config_impl --- omegaconf/basecontainer.py | 54 +++++++++++++++++++------------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 9f98a33b3..84ca69fc6 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -244,33 +244,6 @@ def convert(val: Node) -> Any: else: retdict[key] = convert(node) - def _instantiate_structured_config_impl( - retdict, object_type, allow_objects - ): - from ._utils import get_structured_config_data - - object_type_field_names = get_structured_config_data( - object_type, allow_objects=allow_objects - ).keys() - if issubclass(object_type, dict): - # Extending dict as a subclass - - retdict_field_items = { - k: v for k, v in retdict.items() if k in object_type_field_names - } - retdict_nonfield_items = { - k: v - for k, v in retdict.items() - if k not in object_type_field_names - } - result = object_type(**retdict_field_items) - result.update(retdict_nonfield_items) - else: - # normal structured config - assert set(retdict.keys()) <= set(object_type_field_names) - result = object_type(**retdict) - return result - object_type = conf._metadata.object_type if instantiate_structured_configs and is_structured_config(object_type): retstruct = _instantiate_structured_config_impl( @@ -843,3 +816,30 @@ def _update_types(node: Node, ref_type: type, object_type: Optional[type]) -> No if new_ref_type is not Any: node._metadata.ref_type = new_ref_type node._metadata.optional = new_is_optional + +def _instantiate_structured_config_impl( + retdict, object_type, allow_objects +): + from ._utils import get_structured_config_data + + object_type_field_names = get_structured_config_data( + object_type, allow_objects=allow_objects + ).keys() + if issubclass(object_type, dict): + # Extending dict as a subclass + + retdict_field_items = { + k: v for k, v in retdict.items() if k in object_type_field_names + } + retdict_nonfield_items = { + k: v + for k, v in retdict.items() + if k not in object_type_field_names + } + result = object_type(**retdict_field_items) + result.update(retdict_nonfield_items) + else: + # normal structured config + assert set(retdict.keys()) <= set(object_type_field_names) + result = object_type(**retdict) + return result From 653a54eca98a12f4c63b890bf948de38ac9b8cd3 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Sat, 30 Jan 2021 16:58:41 -0600 Subject: [PATCH 17/85] fix lint/mypy errors --- omegaconf/basecontainer.py | 12 ++++++------ tests/structured_conf/test_structured_config.py | 2 ++ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 84ca69fc6..9edbb4412 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from enum import Enum from textwrap import dedent -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union import yaml @@ -246,6 +246,7 @@ def convert(val: Node) -> Any: object_type = conf._metadata.object_type if instantiate_structured_configs and is_structured_config(object_type): + assert object_type is not None retstruct = _instantiate_structured_config_impl( retdict=retdict, object_type=object_type, @@ -817,9 +818,10 @@ def _update_types(node: Node, ref_type: type, object_type: Optional[type]) -> No node._metadata.ref_type = new_ref_type node._metadata.optional = new_is_optional + def _instantiate_structured_config_impl( - retdict, object_type, allow_objects -): + retdict: Dict[str, Any], object_type: Type[Any], allow_objects: Optional[bool] +) -> Any: from ._utils import get_structured_config_data object_type_field_names = get_structured_config_data( @@ -832,9 +834,7 @@ def _instantiate_structured_config_impl( k: v for k, v in retdict.items() if k in object_type_field_names } retdict_nonfield_items = { - k: v - for k, v in retdict.items() - if k not in object_type_field_names + k: v for k, v in retdict.items() if k not in object_type_field_names } result = object_type(**retdict_field_items) result.update(retdict_nonfield_items) diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index e1ed6197c..822df301f 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -1109,6 +1109,7 @@ def test_str2user(self, class_type: str) -> None: cfg[Color.BLUE] = "nope" data = OmegaConf.to_container(cfg, instantiate_structured_configs=True) + assert isinstance(data, module.DictSubclass.Str2User) assert type(data) is module.DictSubclass.Str2User assert type(data["bond"]) is module.User assert data["bond"] == module.User("James Bond", 7) @@ -1135,6 +1136,7 @@ def test_str2user_with_field(self, class_type: str) -> None: cfg[Color.BLUE] = "nope" data = OmegaConf.to_container(cfg, instantiate_structured_configs=True) + assert isinstance(data, module.DictSubclass.Str2UserWithField) assert type(data) is module.DictSubclass.Str2UserWithField assert type(data.foo) is module.User assert data.foo == module.User("Bond", 7) From 65497eef0852392a7d3aa24e4c8b2313991bf870 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Sat, 30 Jan 2021 23:59:55 -0600 Subject: [PATCH 18/85] remove unnecessary allow_objects flag --- omegaconf/basecontainer.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 9edbb4412..8469daf15 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -248,9 +248,7 @@ def convert(val: Node) -> Any: if instantiate_structured_configs and is_structured_config(object_type): assert object_type is not None retstruct = _instantiate_structured_config_impl( - retdict=retdict, - object_type=object_type, - allow_objects=conf._get_flag("allow_objects"), + retdict=retdict, object_type=object_type ) return retstruct else: @@ -820,13 +818,11 @@ def _update_types(node: Node, ref_type: type, object_type: Optional[type]) -> No def _instantiate_structured_config_impl( - retdict: Dict[str, Any], object_type: Type[Any], allow_objects: Optional[bool] + retdict: Dict[str, Any], object_type: Type[Any] ) -> Any: from ._utils import get_structured_config_data - object_type_field_names = get_structured_config_data( - object_type, allow_objects=allow_objects - ).keys() + object_type_field_names = get_structured_config_data(object_type).keys() if issubclass(object_type, dict): # Extending dict as a subclass From 1640f8b5292a39f292a0d68d1264b69392a319ad Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 3 Feb 2021 13:32:24 -0600 Subject: [PATCH 19/85] rename parameter 'instantiate_structured_configs' -> 'instantiate' --- omegaconf/basecontainer.py | 10 +++++----- omegaconf/omegaconf.py | 4 ++-- tests/structured_conf/test_structured_config.py | 10 ++++------ tests/test_base_config.py | 8 +++----- 4 files changed, 14 insertions(+), 18 deletions(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 8469daf15..73ae24fb5 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -193,12 +193,12 @@ def _to_content( resolve: bool, enum_to_str: bool = False, exclude_structured_configs: bool = False, - instantiate_structured_configs: bool = False, + instantiate: bool = False, ) -> Union[None, Any, str, Dict[DictKeyType, Any], List[Any]]: from .dictconfig import DictConfig from .listconfig import ListConfig - if exclude_structured_configs and instantiate_structured_configs: + if exclude_structured_configs and instantiate: raise ValueError( "Cannot both exclude and and instantiate structured configs" ) @@ -239,13 +239,13 @@ def convert(val: Node) -> Any: resolve=resolve, enum_to_str=enum_to_str, exclude_structured_configs=exclude_structured_configs, - instantiate_structured_configs=instantiate_structured_configs, + instantiate=instantiate, ) else: retdict[key] = convert(node) object_type = conf._metadata.object_type - if instantiate_structured_configs and is_structured_config(object_type): + if instantiate and is_structured_config(object_type): assert object_type is not None retstruct = _instantiate_structured_config_impl( retdict=retdict, object_type=object_type @@ -269,7 +269,7 @@ def convert(val: Node) -> Any: resolve=resolve, enum_to_str=enum_to_str, exclude_structured_configs=exclude_structured_configs, - instantiate_structured_configs=instantiate_structured_configs, + instantiate=instantiate, ) retlist.append(item) else: diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index 4ca21fbb8..acd95a67c 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -580,7 +580,7 @@ def to_container( resolve: bool = False, enum_to_str: bool = False, exclude_structured_configs: bool = False, - instantiate_structured_configs: bool = False, + instantiate: bool = False, ) -> Union[Dict[DictKeyType, Any], List[Any], None, str]: """ Resursively converts an OmegaConf config to a primitive container (dict or list). @@ -601,7 +601,7 @@ def to_container( resolve=resolve, enum_to_str=enum_to_str, exclude_structured_configs=exclude_structured_configs, - instantiate_structured_configs=instantiate_structured_configs, + instantiate=instantiate, ) @staticmethod diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index 822df301f..d5bc41621 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -872,9 +872,7 @@ def module(self, class_type: str) -> Any: def round_trip_to_container(self, input_data: Any) -> Any: serialized = OmegaConf.create(input_data) - round_tripped = OmegaConf.to_container( - serialized, instantiate_structured_configs=True - ) + round_tripped = OmegaConf.to_container(serialized, instantiate=True) return round_tripped def test_basic(self, module: Any) -> None: @@ -1108,7 +1106,7 @@ def test_str2user(self, class_type: str) -> None: # bad key cfg[Color.BLUE] = "nope" - data = OmegaConf.to_container(cfg, instantiate_structured_configs=True) + data = OmegaConf.to_container(cfg, instantiate=True) assert isinstance(data, module.DictSubclass.Str2User) assert type(data) is module.DictSubclass.Str2User assert type(data["bond"]) is module.User @@ -1135,7 +1133,7 @@ def test_str2user_with_field(self, class_type: str) -> None: # bad key cfg[Color.BLUE] = "nope" - data = OmegaConf.to_container(cfg, instantiate_structured_configs=True) + data = OmegaConf.to_container(cfg, instantiate=True) assert isinstance(data, module.DictSubclass.Str2UserWithField) assert type(data) is module.DictSubclass.Str2UserWithField assert type(data.foo) is module.User @@ -1153,7 +1151,7 @@ def test_str2str_with_field(self, class_type: str) -> None: with pytest.raises(KeyValidationError): cfg[Color.RED] = "fail" - data = OmegaConf.to_container(cfg, instantiate_structured_configs=True) + data = OmegaConf.to_container(cfg, instantiate=True) assert isinstance(data, module.DictSubclass.Str2StrWithField) assert type(data) is module.DictSubclass.Str2StrWithField assert data.foo == "bar" diff --git a/tests/test_base_config.py b/tests/test_base_config.py index be9ef6031..fe7ef3185 100644 --- a/tests/test_base_config.py +++ b/tests/test_base_config.py @@ -135,7 +135,7 @@ def test_exclude_structured_configs(cfg: Any, ex_false: Any, ex_true: Any) -> No ret1 = OmegaConf.to_container(cfg, exclude_structured_configs=False) assert ret1 == ex_false - ret1 = OmegaConf.to_container(cfg, instantiate_structured_configs=True) + ret1 = OmegaConf.to_container(cfg, instantiate=True) assert ret1 == ex_true ret1 = OmegaConf.to_container(cfg, exclude_structured_configs=True) @@ -207,7 +207,7 @@ def test_to_container(src: Any, expected: Any, expected_with_resolve: Any) -> No cfg = OmegaConf.create(src) container = OmegaConf.to_container(cfg) assert container == expected - container = OmegaConf.to_container(cfg, instantiate_structured_configs=True) + container = OmegaConf.to_container(cfg, instantiate=True) assert container == expected container = OmegaConf.to_container(cfg, resolve=True) assert container == expected_with_resolve @@ -224,9 +224,7 @@ def test_to_container_invalid_input() -> None: def test_to_container_options_mutually_exclusive() -> None: with raises(ValueError): cfg = OmegaConf.create() - OmegaConf.to_container( - cfg, exclude_structured_configs=True, instantiate_structured_configs=True - ) + OmegaConf.to_container(cfg, exclude_structured_configs=True, instantiate=True) def test_string_interpolation_with_readonly_parent() -> None: From 7e52b24e12b7d635b079cb381b628672ceaabfed Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 3 Feb 2021 13:43:49 -0600 Subject: [PATCH 20/85] create OmegaConf.to_object alias for OmegaConf.to_container --- omegaconf/omegaconf.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index acd95a67c..a543a530b 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -589,6 +589,9 @@ def to_container( :param enum_to_str: True to convert Enum values to strings :param exclude_structured_configs: If True, do not convert Structured Configs (DictConfigs backed by a dataclass) + :param instantiate: If True, this function will instantiate structured configs + (DictConfigs backed by a dataclass), by creating an instance + of the underlying dataclass. See also OmegaConf.to_object. :return: A dict or a list representing this config as a primitive container. """ if not OmegaConf.is_config(cfg): @@ -604,6 +607,33 @@ def to_container( instantiate=instantiate, ) + @staticmethod + def to_object( + cfg: Any, + *, + resolve: bool = False, + enum_to_str: bool = False, + ) -> Union[Dict[DictKeyType, Any], List[Any], None, str, Any]: + """ + Resursively converts an OmegaConf config to a primitive container (dict or list). + Any DictConfig objects backed by dataclasses or attrs classes are instantiated + as instances of those backing classes. + + This is an alias for OmegaConf.to_container(..., exclude_structured_configs=Flase, instantiate=True) + + :param cfg: the config to convert + :param resolve: True to resolve all values + :param enum_to_str: True to convert Enum values to strings + :return: A dict or a list or dataclass representing this config. + """ + return OmegaConf.to_container( + cfg=cfg, + resolve=resolve, + enum_to_str=enum_to_str, + exclude_structured_configs=False, + instantiate=True, + ) + @staticmethod def is_missing(cfg: Any, key: DictKeyType) -> bool: assert isinstance(cfg, Container) From 8dfd397d32a0f4d5c6bd7a53f6e3f6fa549e192e Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 3 Feb 2021 13:47:42 -0600 Subject: [PATCH 21/85] One use case per test --- .../structured_conf/test_structured_config.py | 46 ++++++++++--------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index d5bc41621..08dfe5944 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -876,26 +876,20 @@ def round_trip_to_container(self, input_data: Any) -> Any: return round_tripped def test_basic(self, module: Any) -> None: - user = self.round_trip_to_container(module.User()) - assert isinstance(user, module.User) - assert type(user) is module.User - assert user.name is MISSING - assert user.age is MISSING - user = self.round_trip_to_container(module.User("Bond", 7)) assert isinstance(user, module.User) assert type(user) is module.User assert user.name == "Bond" assert user.age == 7 - def test_nested(self, module: Any) -> None: - data = self.round_trip_to_container({1: module.User()}) - user = data[1] + def test_basic_with_missing(self, module: Any) -> None: + user = self.round_trip_to_container(module.User()) assert isinstance(user, module.User) assert type(user) is module.User assert user.name is MISSING assert user.age is MISSING + def test_nested(self, module: Any) -> None: data = self.round_trip_to_container({1: module.User("Bond", 7)}) user = data[1] assert isinstance(user, module.User) @@ -903,13 +897,15 @@ def test_nested(self, module: Any) -> None: assert user.name == "Bond" assert user.age == 7 - def test_list(self, module: Any) -> None: - lst = self.round_trip_to_container(module.UserList) - assert isinstance(lst, module.UserList) - assert type(lst) is module.UserList - # assert lst.list is MISSING # fails: lst.list is "???" - assert lst.list == MISSING + def test_nested_with_missing(self, module: Any) -> None: + data = self.round_trip_to_container({1: module.User()}) + user = data[1] + assert isinstance(user, module.User) + assert type(user) is module.User + assert user.name is MISSING + assert user.age is MISSING + def test_list(self, module: Any) -> None: lst = self.round_trip_to_container( module.UserList([module.User("Bond", 7)]) ) @@ -922,13 +918,14 @@ def test_list(self, module: Any) -> None: assert user.name == "Bond" assert user.age == 7 - def test_dict(self, module: Any) -> None: - user_dict = self.round_trip_to_container(module.UserDict) - assert isinstance(user_dict, module.UserDict) - assert type(user_dict) is module.UserDict - # assert user_dict.dict is MISSING # fails: dct.dict is "???" - assert user_dict.dict == MISSING + def test_list_with_missing(self, module: Any) -> None: + lst = self.round_trip_to_container(module.UserList) + assert isinstance(lst, module.UserList) + assert type(lst) is module.UserList + # assert lst.list is MISSING # fails: lst.list is "???" + assert lst.list == MISSING + def test_dict(self, module: Any) -> None: user_dict = self.round_trip_to_container( module.UserDict({"user007": module.User("Bond", 7)}) ) @@ -941,6 +938,13 @@ def test_dict(self, module: Any) -> None: assert user.name == "Bond" assert user.age == 7 + def test_dict_with_missing(self, module: Any) -> None: + user_dict = self.round_trip_to_container(module.UserDict) + assert isinstance(user_dict, module.UserDict) + assert type(user_dict) is module.UserDict + # assert user_dict.dict is MISSING # fails: dct.dict is "???" + assert user_dict.dict == MISSING + def validate_frozen_impl(conf: DictConfig) -> None: with pytest.raises(ReadonlyConfigError): From 68b1f74b70d419597cd3704137e64410dcbf3472 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 3 Feb 2021 13:58:02 -0600 Subject: [PATCH 22/85] coverage: use to_object(cfg) instead of to_container(object, instantiate=True) --- tests/structured_conf/test_structured_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index 08dfe5944..32fffc004 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -872,7 +872,7 @@ def module(self, class_type: str) -> Any: def round_trip_to_container(self, input_data: Any) -> Any: serialized = OmegaConf.create(input_data) - round_tripped = OmegaConf.to_container(serialized, instantiate=True) + round_tripped = OmegaConf.to_object(serialized) return round_tripped def test_basic(self, module: Any) -> None: From 5b470497de1f930213c5d4a44ef7a4cf9bcdfd62 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 3 Feb 2021 13:59:23 -0600 Subject: [PATCH 23/85] rename tests: to_object instead of to_container --- .../structured_conf/test_structured_config.py | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index 32fffc004..ae10cdea0 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -870,27 +870,27 @@ def module(self, class_type: str) -> Any: module: Any = import_module(class_type) return module - def round_trip_to_container(self, input_data: Any) -> Any: + def round_trip_to_object(self, input_data: Any) -> Any: serialized = OmegaConf.create(input_data) round_tripped = OmegaConf.to_object(serialized) return round_tripped def test_basic(self, module: Any) -> None: - user = self.round_trip_to_container(module.User("Bond", 7)) + user = self.round_trip_to_object(module.User("Bond", 7)) assert isinstance(user, module.User) assert type(user) is module.User assert user.name == "Bond" assert user.age == 7 def test_basic_with_missing(self, module: Any) -> None: - user = self.round_trip_to_container(module.User()) + user = self.round_trip_to_object(module.User()) assert isinstance(user, module.User) assert type(user) is module.User assert user.name is MISSING assert user.age is MISSING def test_nested(self, module: Any) -> None: - data = self.round_trip_to_container({1: module.User("Bond", 7)}) + data = self.round_trip_to_object({1: module.User("Bond", 7)}) user = data[1] assert isinstance(user, module.User) assert type(user) is module.User @@ -898,7 +898,7 @@ def test_nested(self, module: Any) -> None: assert user.age == 7 def test_nested_with_missing(self, module: Any) -> None: - data = self.round_trip_to_container({1: module.User()}) + data = self.round_trip_to_object({1: module.User()}) user = data[1] assert isinstance(user, module.User) assert type(user) is module.User @@ -906,9 +906,7 @@ def test_nested_with_missing(self, module: Any) -> None: assert user.age is MISSING def test_list(self, module: Any) -> None: - lst = self.round_trip_to_container( - module.UserList([module.User("Bond", 7)]) - ) + lst = self.round_trip_to_object(module.UserList([module.User("Bond", 7)])) assert isinstance(lst, module.UserList) assert type(lst) is module.UserList assert len(lst.list) == 1 @@ -919,14 +917,14 @@ def test_list(self, module: Any) -> None: assert user.age == 7 def test_list_with_missing(self, module: Any) -> None: - lst = self.round_trip_to_container(module.UserList) + lst = self.round_trip_to_object(module.UserList) assert isinstance(lst, module.UserList) assert type(lst) is module.UserList # assert lst.list is MISSING # fails: lst.list is "???" assert lst.list == MISSING def test_dict(self, module: Any) -> None: - user_dict = self.round_trip_to_container( + user_dict = self.round_trip_to_object( module.UserDict({"user007": module.User("Bond", 7)}) ) assert isinstance(user_dict, module.UserDict) @@ -939,7 +937,7 @@ def test_dict(self, module: Any) -> None: assert user.age == 7 def test_dict_with_missing(self, module: Any) -> None: - user_dict = self.round_trip_to_container(module.UserDict) + user_dict = self.round_trip_to_object(module.UserDict) assert isinstance(user_dict, module.UserDict) assert type(user_dict) is module.UserDict # assert user_dict.dict is MISSING # fails: dct.dict is "???" From 15324e97cbbc942c9169ea31835f7499c42ba458 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 3 Feb 2021 14:00:35 -0600 Subject: [PATCH 24/85] tests: user str key instead of int key --- tests/structured_conf/test_structured_config.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index ae10cdea0..e593643be 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -890,16 +890,16 @@ def test_basic_with_missing(self, module: Any) -> None: assert user.age is MISSING def test_nested(self, module: Any) -> None: - data = self.round_trip_to_object({1: module.User("Bond", 7)}) - user = data[1] + data = self.round_trip_to_object({"user": module.User("Bond", 7)}) + user = data["user"] assert isinstance(user, module.User) assert type(user) is module.User assert user.name == "Bond" assert user.age == 7 def test_nested_with_missing(self, module: Any) -> None: - data = self.round_trip_to_object({1: module.User()}) - user = data[1] + data = self.round_trip_to_object({"user": module.User()}) + user = data["user"] assert isinstance(user, module.User) assert type(user) is module.User assert user.name is MISSING From 375babf0f9cf79359cd0bd561c8883c49570ee14 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 3 Feb 2021 14:05:44 -0600 Subject: [PATCH 25/85] tests: change 'assert ... is MISSING' -> 'assert ... == MISSING' --- tests/structured_conf/test_structured_config.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index e593643be..f7ab9643c 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -886,8 +886,8 @@ def test_basic_with_missing(self, module: Any) -> None: user = self.round_trip_to_object(module.User()) assert isinstance(user, module.User) assert type(user) is module.User - assert user.name is MISSING - assert user.age is MISSING + assert user.name == MISSING + assert user.age == MISSING def test_nested(self, module: Any) -> None: data = self.round_trip_to_object({"user": module.User("Bond", 7)}) @@ -902,8 +902,8 @@ def test_nested_with_missing(self, module: Any) -> None: user = data["user"] assert isinstance(user, module.User) assert type(user) is module.User - assert user.name is MISSING - assert user.age is MISSING + assert user.name == MISSING + assert user.age == MISSING def test_list(self, module: Any) -> None: lst = self.round_trip_to_object(module.UserList([module.User("Bond", 7)])) @@ -920,7 +920,6 @@ def test_list_with_missing(self, module: Any) -> None: lst = self.round_trip_to_object(module.UserList) assert isinstance(lst, module.UserList) assert type(lst) is module.UserList - # assert lst.list is MISSING # fails: lst.list is "???" assert lst.list == MISSING def test_dict(self, module: Any) -> None: @@ -940,7 +939,6 @@ def test_dict_with_missing(self, module: Any) -> None: user_dict = self.round_trip_to_object(module.UserDict) assert isinstance(user_dict, module.UserDict) assert type(user_dict) is module.UserDict - # assert user_dict.dict is MISSING # fails: dct.dict is "???" assert user_dict.dict == MISSING From 6422f39fb5ac07d89a7ee8ad09dc7654f849e4a0 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 3 Feb 2021 14:19:14 -0600 Subject: [PATCH 26/85] add tests for object nested inside object --- .../structured_conf/test_structured_config.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index f7ab9643c..7b5d0fbf8 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -941,6 +941,26 @@ def test_dict_with_missing(self, module: Any) -> None: assert type(user_dict) is module.UserDict assert user_dict.dict == MISSING + def test_nested_object(self, module: Any) -> None: + nested = self.round_trip_to_object(module.NestedConfig) + assert isinstance(nested, module.NestedConfig) + assert type(nested) is module.NestedConfig + + assert nested.default_value == MISSING + + assert isinstance(nested.user_provided_default, module.Nested) + assert type(nested.user_provided_default) is module.Nested + assert nested.user_provided_default.with_default == 42 + + def test_nested_object_with_Any_ref_type(self, module: Any) -> None: + nested = self.round_trip_to_object(module.NestedWithAny) + assert isinstance(nested, module.NestedWithAny) + assert type(nested) is module.NestedWithAny + + assert isinstance(nested.var, module.Nested) + assert type(nested.var) is module.Nested + assert nested.var.with_default == 10 + def validate_frozen_impl(conf: DictConfig) -> None: with pytest.raises(ReadonlyConfigError): From ecc05b100e7f59d4c3d94a95a00e95e5432f058f Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 3 Feb 2021 15:02:06 -0600 Subject: [PATCH 27/85] one use case per tests: dict subclass --- tests/structured_conf/test_structured_config.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index 7b5d0fbf8..559394aa1 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -1113,8 +1113,8 @@ def test_color2color(self, class_type: str) -> None: def test_str2user(self, class_type: str) -> None: module: Any = import_module(class_type) cfg = OmegaConf.structured(module.DictSubclass.Str2User()) - cfg.bond = module.User(name="James Bond", age=7) + assert cfg.bond.name == "James Bond" assert cfg.bond.age == 7 @@ -1126,7 +1126,12 @@ def test_str2user(self, class_type: str) -> None: # bad key cfg[Color.BLUE] = "nope" + def test_str2user_instantiate(self, class_type: str) -> None: + module: Any = import_module(class_type) + cfg = OmegaConf.structured(module.DictSubclass.Str2User()) + cfg.bond = module.User(name="James Bond", age=7) data = OmegaConf.to_container(cfg, instantiate=True) + assert isinstance(data, module.DictSubclass.Str2User) assert type(data) is module.DictSubclass.Str2User assert type(data["bond"]) is module.User @@ -1153,6 +1158,10 @@ def test_str2user_with_field(self, class_type: str) -> None: # bad key cfg[Color.BLUE] = "nope" + def test_str2user_with_field_instantiate(self, class_type: str) -> None: + module: Any = import_module(class_type) + cfg = OmegaConf.structured(module.DictSubclass.Str2UserWithField()) + cfg.mp = module.User(name="Moneypenny", age=11) data = OmegaConf.to_container(cfg, instantiate=True) assert isinstance(data, module.DictSubclass.Str2UserWithField) assert type(data) is module.DictSubclass.Str2UserWithField @@ -1171,6 +1180,11 @@ def test_str2str_with_field(self, class_type: str) -> None: with pytest.raises(KeyValidationError): cfg[Color.RED] = "fail" + def test_str2str_with_field_instantiate(self, class_type: str) -> None: + module: Any = import_module(class_type) + cfg = OmegaConf.structured(module.DictSubclass.Str2StrWithField()) + cfg.hello = "world" + data = OmegaConf.to_container(cfg, instantiate=True) assert isinstance(data, module.DictSubclass.Str2StrWithField) assert type(data) is module.DictSubclass.Str2StrWithField From 9d7addce89fa25a968ce4e75cc66f87abac447af Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 3 Feb 2021 15:08:14 -0600 Subject: [PATCH 28/85] test_structured_config.py: consolidate instantiate=True tests --- .../structured_conf/test_structured_config.py | 66 +++++++++---------- 1 file changed, 32 insertions(+), 34 deletions(-) diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index 559394aa1..8d52c574b 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -961,6 +961,38 @@ def test_nested_object_with_Any_ref_type(self, module: Any) -> None: assert type(nested.var) is module.Nested assert nested.var.with_default == 10 + def test_str2user_instantiate(self, module: Any) -> None: + cfg = OmegaConf.structured(module.DictSubclass.Str2User()) + cfg.bond = module.User(name="James Bond", age=7) + data = self.round_trip_to_object(cfg) + + assert isinstance(data, module.DictSubclass.Str2User) + assert type(data) is module.DictSubclass.Str2User + assert type(data["bond"]) is module.User + assert data["bond"] == module.User("James Bond", 7) + + def test_str2user_with_field_instantiate(self, module: Any) -> None: + cfg = OmegaConf.structured(module.DictSubclass.Str2UserWithField()) + cfg.mp = module.User(name="Moneypenny", age=11) + data = self.round_trip_to_object(cfg) + + assert isinstance(data, module.DictSubclass.Str2UserWithField) + assert type(data) is module.DictSubclass.Str2UserWithField + assert type(data.foo) is module.User + assert data.foo == module.User("Bond", 7) + assert type(data["mp"]) is module.User + assert data["mp"] == module.User("Moneypenny", 11) + + def test_str2str_with_field_instantiate(self, module: Any) -> None: + cfg = OmegaConf.structured(module.DictSubclass.Str2StrWithField()) + cfg.hello = "world" + data = self.round_trip_to_object(cfg) + + assert isinstance(data, module.DictSubclass.Str2StrWithField) + assert type(data) is module.DictSubclass.Str2StrWithField + assert data.foo == "bar" + assert data["hello"] == "world" + def validate_frozen_impl(conf: DictConfig) -> None: with pytest.raises(ReadonlyConfigError): @@ -1126,17 +1158,6 @@ def test_str2user(self, class_type: str) -> None: # bad key cfg[Color.BLUE] = "nope" - def test_str2user_instantiate(self, class_type: str) -> None: - module: Any = import_module(class_type) - cfg = OmegaConf.structured(module.DictSubclass.Str2User()) - cfg.bond = module.User(name="James Bond", age=7) - data = OmegaConf.to_container(cfg, instantiate=True) - - assert isinstance(data, module.DictSubclass.Str2User) - assert type(data) is module.DictSubclass.Str2User - assert type(data["bond"]) is module.User - assert data["bond"] == module.User("James Bond", 7) - def test_str2user_with_field(self, class_type: str) -> None: module: Any = import_module(class_type) cfg = OmegaConf.structured(module.DictSubclass.Str2UserWithField()) @@ -1158,18 +1179,6 @@ def test_str2user_with_field(self, class_type: str) -> None: # bad key cfg[Color.BLUE] = "nope" - def test_str2user_with_field_instantiate(self, class_type: str) -> None: - module: Any = import_module(class_type) - cfg = OmegaConf.structured(module.DictSubclass.Str2UserWithField()) - cfg.mp = module.User(name="Moneypenny", age=11) - data = OmegaConf.to_container(cfg, instantiate=True) - assert isinstance(data, module.DictSubclass.Str2UserWithField) - assert type(data) is module.DictSubclass.Str2UserWithField - assert type(data.foo) is module.User - assert data.foo == module.User("Bond", 7) - assert type(data["mp"]) is module.User - assert data["mp"] == module.User("Moneypenny", 11) - def test_str2str_with_field(self, class_type: str) -> None: module: Any = import_module(class_type) cfg = OmegaConf.structured(module.DictSubclass.Str2StrWithField()) @@ -1180,17 +1189,6 @@ def test_str2str_with_field(self, class_type: str) -> None: with pytest.raises(KeyValidationError): cfg[Color.RED] = "fail" - def test_str2str_with_field_instantiate(self, class_type: str) -> None: - module: Any = import_module(class_type) - cfg = OmegaConf.structured(module.DictSubclass.Str2StrWithField()) - cfg.hello = "world" - - data = OmegaConf.to_container(cfg, instantiate=True) - assert isinstance(data, module.DictSubclass.Str2StrWithField) - assert type(data) is module.DictSubclass.Str2StrWithField - assert data.foo == "bar" - assert data["hello"] == "world" - class TestErrors: def test_usr2str(self, class_type: str) -> None: module: Any = import_module(class_type) From 7136baa3056bc0c45b536105a383c66187b8cc6b Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Sun, 14 Feb 2021 00:31:01 -0600 Subject: [PATCH 29/85] finish rebase against master --- tests/test_base_config.py | 172 ------------------------------------- tests/test_to_container.py | 11 +++ 2 files changed, 11 insertions(+), 172 deletions(-) diff --git a/tests/test_base_config.py b/tests/test_base_config.py index fe7ef3185..482b5e0de 100644 --- a/tests/test_base_config.py +++ b/tests/test_base_config.py @@ -80,178 +80,6 @@ def test_replace_value_node_type_with_another( assert c[key] == value._value() -@pytest.mark.parametrize( - "input_", - [ - pytest.param([1, 2, 3], id="list"), - pytest.param([1, 2, {"a": 3}], id="dict_in_list"), - pytest.param([1, 2, [10, 20]], id="list_in_list"), - pytest.param({"b": {"b": 10}}, id="dict_in_dict"), - pytest.param({"b": [False, 1, "2", 3.0, Color.RED]}, id="list_in_dict"), - pytest.param({"b": DictConfig(content=None)}, id="none_dictconfig"), - pytest.param({"b": ListConfig(content=None)}, id="none_listconfig"), - pytest.param({"b": DictConfig(content="???")}, id="missing_dictconfig"), - pytest.param({"b": ListConfig(content="???")}, id="missing_listconfig"), - ], -) -def test_to_container_returns_primitives(input_: Any) -> None: - def assert_container_with_primitives(item: Any) -> None: - if isinstance(item, list): - for v in item: - assert_container_with_primitives(v) - elif isinstance(item, dict): - for _k, v in item.items(): - assert_container_with_primitives(v) - else: - assert isinstance(item, (int, float, str, bool, type(None), Enum)) - - c = OmegaConf.create(input_) - res = OmegaConf.to_container(c, resolve=True) - assert_container_with_primitives(res) - - -@pytest.mark.parametrize( - "cfg,ex_false,ex_true", - [ - pytest.param( - {"user": User(age=7, name="Bond")}, - {"user": {"name": "Bond", "age": 7}}, - {"user": User(age=7, name="Bond")}, - ), - pytest.param( - [1, User(age=7, name="Bond")], - [1, {"name": "Bond", "age": 7}], - [1, User(age=7, name="Bond")], - ), - pytest.param( - {"users": [User(age=1, name="a"), User(age=2, name="b")]}, - {"users": [{"age": 1, "name": "a"}, {"age": 2, "name": "b"}]}, - {"users": [User(age=1, name="a"), User(age=2, name="b")]}, - ), - ], -) -def test_exclude_structured_configs(cfg: Any, ex_false: Any, ex_true: Any) -> None: - cfg = OmegaConf.create(cfg) - ret1 = OmegaConf.to_container(cfg, exclude_structured_configs=False) - assert ret1 == ex_false - - ret1 = OmegaConf.to_container(cfg, instantiate=True) - assert ret1 == ex_true - - ret1 = OmegaConf.to_container(cfg, exclude_structured_configs=True) - assert ret1 == ex_true - - -@pytest.mark.parametrize( - "src, expected, expected_with_resolve", - [ - pytest.param([], None, None, id="empty_list"), - pytest.param([1, 2, 3], None, None, id="list"), - pytest.param([None], None, None, id="list_with_none"), - pytest.param([1, "${0}", 3], None, [1, 1, 3], id="list_with_inter"), - pytest.param({}, None, None, id="empty_dict"), - pytest.param({"foo": "bar"}, None, None, id="dict"), - pytest.param( - {"foo": "${bar}", "bar": "zonk"}, - None, - {"foo": "zonk", "bar": "zonk"}, - id="dict_with_inter", - ), - pytest.param({"foo": None}, None, None, id="dict_with_none"), - pytest.param({"foo": "???"}, None, None, id="dict_missing_value"), - pytest.param({"foo": None}, None, None, id="dict_none_value"), - # containers - pytest.param( - {"foo": DictConfig(is_optional=True, content=None)}, - {"foo": None}, - None, - id="dict_none_dictconfig", - ), - pytest.param( - {"foo": DictConfig(content="???")}, - {"foo": "???"}, - None, - id="dict_missing_dictconfig", - ), - pytest.param( - {"foo": DictConfig(content="${bar}"), "bar": 10}, - {"foo": "${bar}", "bar": 10}, - {"foo": 10, "bar": 10}, - id="dict_inter_dictconfig", - ), - pytest.param( - {"foo": ListConfig(content="???")}, - {"foo": "???"}, - None, - id="dict_missing_listconfig", - ), - pytest.param( - {"foo": ListConfig(is_optional=True, content=None)}, - {"foo": None}, - None, - id="dict_none_listconfig", - ), - pytest.param( - {"foo": ListConfig(content="${bar}"), "bar": 10}, - {"foo": "${bar}", "bar": 10}, - {"foo": 10, "bar": 10}, - id="dict_inter_listconfig", - ), - ], -) -def test_to_container(src: Any, expected: Any, expected_with_resolve: Any) -> None: - if expected is None: - expected = src - if expected_with_resolve is None: - expected_with_resolve = expected - cfg = OmegaConf.create(src) - container = OmegaConf.to_container(cfg) - assert container == expected - container = OmegaConf.to_container(cfg, instantiate=True) - assert container == expected - container = OmegaConf.to_container(cfg, resolve=True) - assert container == expected_with_resolve - - -def test_to_container_invalid_input() -> None: - with pytest.raises( - ValueError, - match=re.escape("Input cfg is not an OmegaConf config object (dict)"), - ): - OmegaConf.to_container({}) - - -def test_to_container_options_mutually_exclusive() -> None: - with raises(ValueError): - cfg = OmegaConf.create() - OmegaConf.to_container(cfg, exclude_structured_configs=True, instantiate=True) - - -def test_string_interpolation_with_readonly_parent() -> None: - cfg = OmegaConf.create({"a": 10, "b": {"c": "hello_${a}"}}) - OmegaConf.set_readonly(cfg, True) - assert OmegaConf.to_container(cfg, resolve=True) == { - "a": 10, - "b": {"c": "hello_10"}, - } - - -@pytest.mark.parametrize( - "src,expected", - [ - pytest.param(DictConfig(content="${bar}"), "${bar}", id="DictConfig"), - pytest.param( - OmegaConf.create({"foo": DictConfig(content="${bar}")}), - {"foo": "${bar}"}, - id="nested_DictConfig", - ), - ], -) -def test_to_container_missing_inter_no_resolve(src: Any, expected: Any) -> None: - res = OmegaConf.to_container(src, resolve=False) - assert res == expected - - @pytest.mark.parametrize( "input_, is_empty", [ diff --git a/tests/test_to_container.py b/tests/test_to_container.py index 48c92cb58..e368c1fb8 100644 --- a/tests/test_to_container.py +++ b/tests/test_to_container.py @@ -63,6 +63,9 @@ def test_exclude_structured_configs(cfg: Any, ex_false: Any, ex_true: Any) -> No ret1 = OmegaConf.to_container(cfg, exclude_structured_configs=False) assert ret1 == ex_false + ret1 = OmegaConf.to_container(cfg, instantiate=True) + assert ret1 == ex_true + ret1 = OmegaConf.to_container(cfg, exclude_structured_configs=True) assert ret1 == ex_true @@ -132,6 +135,8 @@ def test_to_container(src: Any, expected: Any, expected_with_resolve: Any) -> No cfg = OmegaConf.create(src) container = OmegaConf.to_container(cfg) assert container == expected + container = OmegaConf.to_container(cfg, instantiate=True) + assert container == expected container = OmegaConf.to_container(cfg, resolve=True) assert container == expected_with_resolve @@ -144,6 +149,12 @@ def test_to_container_invalid_input() -> None: OmegaConf.to_container({}) +def test_to_container_options_mutually_exclusive() -> None: + with pytest.raises(ValueError): + cfg = OmegaConf.create() + OmegaConf.to_container(cfg, exclude_structured_configs=True, instantiate=True) + + def test_string_interpolation_with_readonly_parent() -> None: cfg = OmegaConf.create({"a": 10, "b": {"c": "hello_${a}"}}) OmegaConf.set_readonly(cfg, True) From 841ac01b73197e22f8195ece8c76649eb305d237 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Tue, 16 Feb 2021 18:02:52 -0600 Subject: [PATCH 30/85] Move TestInstantiateStructuredConfigs to test_to_container.py --- .../structured_conf/test_structured_config.py | 129 ---------------- tests/test_to_container.py | 140 +++++++++++++++++- 2 files changed, 139 insertions(+), 130 deletions(-) diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index 8d52c574b..80458a51a 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -864,135 +864,6 @@ def test_create_untyped_list(self, class_type: str) -> None: assert cfg.list == [1, 2] assert cfg.opt_list is None - class TestInstantiateStructuredConfigs: - @pytest.fixture - def module(self, class_type: str) -> Any: - module: Any = import_module(class_type) - return module - - def round_trip_to_object(self, input_data: Any) -> Any: - serialized = OmegaConf.create(input_data) - round_tripped = OmegaConf.to_object(serialized) - return round_tripped - - def test_basic(self, module: Any) -> None: - user = self.round_trip_to_object(module.User("Bond", 7)) - assert isinstance(user, module.User) - assert type(user) is module.User - assert user.name == "Bond" - assert user.age == 7 - - def test_basic_with_missing(self, module: Any) -> None: - user = self.round_trip_to_object(module.User()) - assert isinstance(user, module.User) - assert type(user) is module.User - assert user.name == MISSING - assert user.age == MISSING - - def test_nested(self, module: Any) -> None: - data = self.round_trip_to_object({"user": module.User("Bond", 7)}) - user = data["user"] - assert isinstance(user, module.User) - assert type(user) is module.User - assert user.name == "Bond" - assert user.age == 7 - - def test_nested_with_missing(self, module: Any) -> None: - data = self.round_trip_to_object({"user": module.User()}) - user = data["user"] - assert isinstance(user, module.User) - assert type(user) is module.User - assert user.name == MISSING - assert user.age == MISSING - - def test_list(self, module: Any) -> None: - lst = self.round_trip_to_object(module.UserList([module.User("Bond", 7)])) - assert isinstance(lst, module.UserList) - assert type(lst) is module.UserList - assert len(lst.list) == 1 - user = lst.list[0] - assert isinstance(user, module.User) - assert type(user) is module.User - assert user.name == "Bond" - assert user.age == 7 - - def test_list_with_missing(self, module: Any) -> None: - lst = self.round_trip_to_object(module.UserList) - assert isinstance(lst, module.UserList) - assert type(lst) is module.UserList - assert lst.list == MISSING - - def test_dict(self, module: Any) -> None: - user_dict = self.round_trip_to_object( - module.UserDict({"user007": module.User("Bond", 7)}) - ) - assert isinstance(user_dict, module.UserDict) - assert type(user_dict) is module.UserDict - assert len(user_dict.dict) == 1 - user = user_dict.dict["user007"] - assert isinstance(user, module.User) - assert type(user) is module.User - assert user.name == "Bond" - assert user.age == 7 - - def test_dict_with_missing(self, module: Any) -> None: - user_dict = self.round_trip_to_object(module.UserDict) - assert isinstance(user_dict, module.UserDict) - assert type(user_dict) is module.UserDict - assert user_dict.dict == MISSING - - def test_nested_object(self, module: Any) -> None: - nested = self.round_trip_to_object(module.NestedConfig) - assert isinstance(nested, module.NestedConfig) - assert type(nested) is module.NestedConfig - - assert nested.default_value == MISSING - - assert isinstance(nested.user_provided_default, module.Nested) - assert type(nested.user_provided_default) is module.Nested - assert nested.user_provided_default.with_default == 42 - - def test_nested_object_with_Any_ref_type(self, module: Any) -> None: - nested = self.round_trip_to_object(module.NestedWithAny) - assert isinstance(nested, module.NestedWithAny) - assert type(nested) is module.NestedWithAny - - assert isinstance(nested.var, module.Nested) - assert type(nested.var) is module.Nested - assert nested.var.with_default == 10 - - def test_str2user_instantiate(self, module: Any) -> None: - cfg = OmegaConf.structured(module.DictSubclass.Str2User()) - cfg.bond = module.User(name="James Bond", age=7) - data = self.round_trip_to_object(cfg) - - assert isinstance(data, module.DictSubclass.Str2User) - assert type(data) is module.DictSubclass.Str2User - assert type(data["bond"]) is module.User - assert data["bond"] == module.User("James Bond", 7) - - def test_str2user_with_field_instantiate(self, module: Any) -> None: - cfg = OmegaConf.structured(module.DictSubclass.Str2UserWithField()) - cfg.mp = module.User(name="Moneypenny", age=11) - data = self.round_trip_to_object(cfg) - - assert isinstance(data, module.DictSubclass.Str2UserWithField) - assert type(data) is module.DictSubclass.Str2UserWithField - assert type(data.foo) is module.User - assert data.foo == module.User("Bond", 7) - assert type(data["mp"]) is module.User - assert data["mp"] == module.User("Moneypenny", 11) - - def test_str2str_with_field_instantiate(self, module: Any) -> None: - cfg = OmegaConf.structured(module.DictSubclass.Str2StrWithField()) - cfg.hello = "world" - data = self.round_trip_to_object(cfg) - - assert isinstance(data, module.DictSubclass.Str2StrWithField) - assert type(data) is module.DictSubclass.Str2StrWithField - assert data.foo == "bar" - assert data["hello"] == "world" - def validate_frozen_impl(conf: DictConfig) -> None: with pytest.raises(ReadonlyConfigError): diff --git a/tests/test_to_container.py b/tests/test_to_container.py index e368c1fb8..99cb2ddbc 100644 --- a/tests/test_to_container.py +++ b/tests/test_to_container.py @@ -1,10 +1,11 @@ import re from enum import Enum +from importlib import import_module from typing import Any import pytest -from omegaconf import DictConfig, ListConfig, OmegaConf +from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf from tests import Color, User @@ -178,3 +179,140 @@ def test_string_interpolation_with_readonly_parent() -> None: def test_to_container_missing_inter_no_resolve(src: Any, expected: Any) -> None: res = OmegaConf.to_container(src, resolve=False) assert res == expected + + +@pytest.mark.parametrize( + "class_type", + [ + "tests.structured_conf.data.dataclasses", + "tests.structured_conf.data.attr_classes", + ], +) +class TestInstantiateStructuredConfigs: + @pytest.fixture + def module(self, class_type: str) -> Any: + module: Any = import_module(class_type) + return module + + def round_trip_to_object(self, input_data: Any) -> Any: + serialized = OmegaConf.create(input_data) + round_tripped = OmegaConf.to_object(serialized) + return round_tripped + + def test_basic(self, module: Any) -> None: + user = self.round_trip_to_object(module.User("Bond", 7)) + assert isinstance(user, module.User) + assert type(user) is module.User + assert user.name == "Bond" + assert user.age == 7 + + def test_basic_with_missing(self, module: Any) -> None: + user = self.round_trip_to_object(module.User()) + assert isinstance(user, module.User) + assert type(user) is module.User + assert user.name == MISSING + assert user.age == MISSING + + def test_nested(self, module: Any) -> None: + data = self.round_trip_to_object({"user": module.User("Bond", 7)}) + user = data["user"] + assert isinstance(user, module.User) + assert type(user) is module.User + assert user.name == "Bond" + assert user.age == 7 + + def test_nested_with_missing(self, module: Any) -> None: + data = self.round_trip_to_object({"user": module.User()}) + user = data["user"] + assert isinstance(user, module.User) + assert type(user) is module.User + assert user.name == MISSING + assert user.age == MISSING + + def test_list(self, module: Any) -> None: + lst = self.round_trip_to_object(module.UserList([module.User("Bond", 7)])) + assert isinstance(lst, module.UserList) + assert type(lst) is module.UserList + assert len(lst.list) == 1 + user = lst.list[0] + assert isinstance(user, module.User) + assert type(user) is module.User + assert user.name == "Bond" + assert user.age == 7 + + def test_list_with_missing(self, module: Any) -> None: + lst = self.round_trip_to_object(module.UserList) + assert isinstance(lst, module.UserList) + assert type(lst) is module.UserList + assert lst.list == MISSING + + def test_dict(self, module: Any) -> None: + user_dict = self.round_trip_to_object( + module.UserDict({"user007": module.User("Bond", 7)}) + ) + assert isinstance(user_dict, module.UserDict) + assert type(user_dict) is module.UserDict + assert len(user_dict.dict) == 1 + user = user_dict.dict["user007"] + assert isinstance(user, module.User) + assert type(user) is module.User + assert user.name == "Bond" + assert user.age == 7 + + def test_dict_with_missing(self, module: Any) -> None: + user_dict = self.round_trip_to_object(module.UserDict) + assert isinstance(user_dict, module.UserDict) + assert type(user_dict) is module.UserDict + assert user_dict.dict == MISSING + + def test_nested_object(self, module: Any) -> None: + nested = self.round_trip_to_object(module.NestedConfig) + assert isinstance(nested, module.NestedConfig) + assert type(nested) is module.NestedConfig + + assert nested.default_value == MISSING + + assert isinstance(nested.user_provided_default, module.Nested) + assert type(nested.user_provided_default) is module.Nested + assert nested.user_provided_default.with_default == 42 + + def test_nested_object_with_Any_ref_type(self, module: Any) -> None: + nested = self.round_trip_to_object(module.NestedWithAny) + assert isinstance(nested, module.NestedWithAny) + assert type(nested) is module.NestedWithAny + + assert isinstance(nested.var, module.Nested) + assert type(nested.var) is module.Nested + assert nested.var.with_default == 10 + + def test_str2user_instantiate(self, module: Any) -> None: + cfg = OmegaConf.structured(module.DictSubclass.Str2User()) + cfg.bond = module.User(name="James Bond", age=7) + data = self.round_trip_to_object(cfg) + + assert isinstance(data, module.DictSubclass.Str2User) + assert type(data) is module.DictSubclass.Str2User + assert type(data["bond"]) is module.User + assert data["bond"] == module.User("James Bond", 7) + + def test_str2user_with_field_instantiate(self, module: Any) -> None: + cfg = OmegaConf.structured(module.DictSubclass.Str2UserWithField()) + cfg.mp = module.User(name="Moneypenny", age=11) + data = self.round_trip_to_object(cfg) + + assert isinstance(data, module.DictSubclass.Str2UserWithField) + assert type(data) is module.DictSubclass.Str2UserWithField + assert type(data.foo) is module.User + assert data.foo == module.User("Bond", 7) + assert type(data["mp"]) is module.User + assert data["mp"] == module.User("Moneypenny", 11) + + def test_str2str_with_field_instantiate(self, module: Any) -> None: + cfg = OmegaConf.structured(module.DictSubclass.Str2StrWithField()) + cfg.hello = "world" + data = self.round_trip_to_object(cfg) + + assert isinstance(data, module.DictSubclass.Str2StrWithField) + assert type(data) is module.DictSubclass.Str2StrWithField + assert data.foo == "bar" + assert data["hello"] == "world" From fa81a4d4bc35342ebd2a76942ce68653d1f41078 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Tue, 16 Feb 2021 18:25:36 -0600 Subject: [PATCH 31/85] Create get_structured_config_field_names function --- omegaconf/_utils.py | 19 +++++++++++++++++++ omegaconf/basecontainer.py | 4 ++-- tests/test_utils.py | 18 ++++++++++++++++++ 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/omegaconf/_utils.py b/omegaconf/_utils.py index 05b002d5e..e6cf918d4 100644 --- a/omegaconf/_utils.py +++ b/omegaconf/_utils.py @@ -192,6 +192,12 @@ def _resolve_forward(type_: Type[Any], module: str) -> Type[Any]: return type_ +def get_attr_class_field_names(obj: Any) -> List[str]: + is_type = isinstance(obj, type) + obj_type = obj if is_type else type(obj) + return [name for name in attr.fields_dict(obj_type).keys()] + + def get_attr_data(obj: Any, allow_objects: Optional[bool] = None) -> Dict[str, Any]: from omegaconf.omegaconf import OmegaConf, _maybe_wrap @@ -228,6 +234,10 @@ def get_attr_data(obj: Any, allow_objects: Optional[bool] = None) -> Dict[str, A return d +def get_dataclass_field_names(obj: Any) -> List[str]: + return [field.name for field in dataclasses.fields(obj)] + + def get_dataclass_data( obj: Any, allow_objects: Optional[bool] = None ) -> Dict[str, Any]: @@ -317,6 +327,15 @@ def is_structured_config_frozen(obj: Any) -> bool: return False +def get_structured_config_field_names(obj: Any) -> List[str]: + if is_dataclass(obj): + return get_dataclass_field_names(obj) + elif is_attr_class(obj): + return get_attr_class_field_names(obj) + else: + raise ValueError(f"Unsupported type: {type(obj).__name__}") + + def get_structured_config_data( obj: Any, allow_objects: Optional[bool] = None ) -> Dict[str, Any]: diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 73ae24fb5..46d84250d 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -820,9 +820,9 @@ def _update_types(node: Node, ref_type: type, object_type: Optional[type]) -> No def _instantiate_structured_config_impl( retdict: Dict[str, Any], object_type: Type[Any] ) -> Any: - from ._utils import get_structured_config_data + from ._utils import get_structured_config_field_names - object_type_field_names = get_structured_config_data(object_type).keys() + object_type_field_names = get_structured_config_field_names(object_type) if issubclass(object_type, dict): # Extending dict as a subclass diff --git a/tests/test_utils.py b/tests/test_utils.py index cebf2bdb7..d2c3aa773 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -192,6 +192,24 @@ def test_get_structured_config_data(test_cls_or_obj: Any, expectation: Any) -> N assert d["dict1"] == {} +@mark.parametrize( + "test_cls_or_obj, expectation", + [ + (_TestDataclass, does_not_raise()), + (_TestDataclass(), does_not_raise()), + (_TestAttrsClass, does_not_raise()), + (_TestAttrsClass(), does_not_raise()), + ("invalid", raises(ValueError)), + ], +) +def test_get_structured_config_field_names( + test_cls_or_obj: Any, expectation: Any +) -> None: + with expectation: + field_names = _utils.get_structured_config_field_names(test_cls_or_obj) + assert field_names == ["x", "s", "b", "f", "e", "list1", "dict1"] + + @mark.parametrize( "test_cls", [ From 872850b96ada2890a5ce45248e23c50eb23a72d5 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 17 Feb 2021 12:04:50 -0600 Subject: [PATCH 32/85] OmegaConf.to_object: resolve=True by default --- omegaconf/omegaconf.py | 2 +- tests/test_to_container.py | 27 ++++++++++++++++++++++++--- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index a543a530b..17dd31898 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -611,7 +611,7 @@ def to_container( def to_object( cfg: Any, *, - resolve: bool = False, + resolve: bool = True, enum_to_str: bool = False, ) -> Union[Dict[DictKeyType, Any], List[Any], None, str, Any]: """ diff --git a/tests/test_to_container.py b/tests/test_to_container.py index 99cb2ddbc..055ffa407 100644 --- a/tests/test_to_container.py +++ b/tests/test_to_container.py @@ -6,6 +6,7 @@ import pytest from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf +from omegaconf.errors import InterpolationResolutionError from tests import Color, User @@ -194,9 +195,9 @@ def module(self, class_type: str) -> Any: module: Any = import_module(class_type) return module - def round_trip_to_object(self, input_data: Any) -> Any: + def round_trip_to_object(self, input_data: Any, **kwargs: Any) -> Any: serialized = OmegaConf.create(input_data) - round_tripped = OmegaConf.to_object(serialized) + round_tripped = OmegaConf.to_object(serialized, **kwargs) return round_tripped def test_basic(self, module: Any) -> None: @@ -276,8 +277,28 @@ def test_nested_object(self, module: Any) -> None: assert type(nested.user_provided_default) is module.Nested assert nested.user_provided_default.with_default == 42 + def test_to_object_resolve_is_True_by_default(self, module: Any) -> None: + interp = self.round_trip_to_object(module.Interpolation) + assert isinstance(interp, module.Interpolation) + assert type(interp) is module.Interpolation + + assert interp.z1 == 100 + assert interp.z2 == "100_200" + + def test_to_object_resolve_False(self, module: Any) -> None: + interp = self.round_trip_to_object(module.Interpolation, resolve=False) + assert isinstance(interp, module.Interpolation) + assert type(interp) is module.Interpolation + + assert interp.z1 == "${x}" + assert interp.z2 == "${x}_${y}" + + def test_to_object_InterpolationResolutionError(self, module: Any) -> None: + with pytest.raises(InterpolationResolutionError): + self.round_trip_to_object(module.NestedWithAny) + def test_nested_object_with_Any_ref_type(self, module: Any) -> None: - nested = self.round_trip_to_object(module.NestedWithAny) + nested = self.round_trip_to_object(module.NestedWithAny, resolve=False) assert isinstance(nested, module.NestedWithAny) assert type(nested) is module.NestedWithAny From 072e8d861ce6a8df13fdbe0d67979be3874068cd Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 17 Feb 2021 12:26:18 -0600 Subject: [PATCH 33/85] change _instantiate_structured_config_impl fn signature --- omegaconf/basecontainer.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 46d84250d..0bd6aafa9 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from enum import Enum from textwrap import dedent -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import yaml @@ -244,11 +244,9 @@ def convert(val: Node) -> Any: else: retdict[key] = convert(node) - object_type = conf._metadata.object_type - if instantiate and is_structured_config(object_type): - assert object_type is not None + if instantiate and is_structured_config(conf._metadata.object_type): retstruct = _instantiate_structured_config_impl( - retdict=retdict, object_type=object_type + conf=conf, instance_data=retdict ) return retstruct else: @@ -818,24 +816,27 @@ def _update_types(node: Node, ref_type: type, object_type: Optional[type]) -> No def _instantiate_structured_config_impl( - retdict: Dict[str, Any], object_type: Type[Any] + conf: "DictConfig", instance_data: Dict[str, Any] ) -> Any: + """Instantiate an instance of `conf._metadata.object_type`, populated by `instance_data`.""" from ._utils import get_structured_config_field_names + object_type = conf._metadata.object_type + object_type_field_names = get_structured_config_field_names(object_type) if issubclass(object_type, dict): # Extending dict as a subclass retdict_field_items = { - k: v for k, v in retdict.items() if k in object_type_field_names + k: v for k, v in instance_data.items() if k in object_type_field_names } retdict_nonfield_items = { - k: v for k, v in retdict.items() if k not in object_type_field_names + k: v for k, v in instance_data.items() if k not in object_type_field_names } result = object_type(**retdict_field_items) result.update(retdict_nonfield_items) else: # normal structured config - assert set(retdict.keys()) <= set(object_type_field_names) - result = object_type(**retdict) + assert set(instance_data.keys()) <= set(object_type_field_names) + result = object_type(**instance_data) return result From 1953a47fdfb5b65725febeb3fec213e9641f2cc1 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Tue, 23 Feb 2021 12:50:09 -0600 Subject: [PATCH 34/85] separate positive and negative test cases --- tests/test_utils.py | 60 ++++++++++++++++++++------------------------- 1 file changed, 27 insertions(+), 33 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index d2c3aa773..74bc13c80 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -23,7 +23,7 @@ StringNode, ) from omegaconf.omegaconf import _node_wrap -from tests import Color, ConcretePlugin, IllegalType, Plugin, User, does_not_raise +from tests import Color, ConcretePlugin, IllegalType, Plugin, User @mark.parametrize( @@ -171,43 +171,37 @@ def test_valid_value_annotation_type(type_: type, expected: bool) -> None: @mark.parametrize( - "test_cls_or_obj, expectation", - [ - (_TestDataclass, does_not_raise()), - (_TestDataclass(), does_not_raise()), - (_TestAttrsClass, does_not_raise()), - (_TestAttrsClass(), does_not_raise()), - ("invalid", raises(ValueError)), - ], + "test_cls_or_obj", + [_TestDataclass, _TestDataclass(), _TestAttrsClass, _TestAttrsClass()], ) -def test_get_structured_config_data(test_cls_or_obj: Any, expectation: Any) -> None: - with expectation: - d = _utils.get_structured_config_data(test_cls_or_obj) - assert d["x"] == 10 - assert d["s"] == "foo" - assert d["b"] == bool(True) - assert d["f"] == 3.14 - assert d["e"] == _TestEnum.A - assert d["list1"] == [] - assert d["dict1"] == {} +def test_get_structured_config_data(test_cls_or_obj: Any) -> None: + d = _utils.get_structured_config_data(test_cls_or_obj) + assert d["x"] == 10 + assert d["s"] == "foo" + assert d["b"] == bool(True) + assert d["f"] == 3.14 + assert d["e"] == _TestEnum.A + assert d["list1"] == [] + assert d["dict1"] == {} + + +def test_get_structured_config_data_throws_ValueError() -> None: + with raises(ValueError): + _utils.get_structured_config_data("invalid") @mark.parametrize( - "test_cls_or_obj, expectation", - [ - (_TestDataclass, does_not_raise()), - (_TestDataclass(), does_not_raise()), - (_TestAttrsClass, does_not_raise()), - (_TestAttrsClass(), does_not_raise()), - ("invalid", raises(ValueError)), - ], + "test_cls_or_obj", + [_TestDataclass, _TestDataclass(), _TestAttrsClass, _TestAttrsClass()], ) -def test_get_structured_config_field_names( - test_cls_or_obj: Any, expectation: Any -) -> None: - with expectation: - field_names = _utils.get_structured_config_field_names(test_cls_or_obj) - assert field_names == ["x", "s", "b", "f", "e", "list1", "dict1"] +def test_get_structured_config_field_names(test_cls_or_obj: Any) -> None: + field_names = _utils.get_structured_config_field_names(test_cls_or_obj) + assert field_names == ["x", "s", "b", "f", "e", "list1", "dict1"] + + +def test_get_structured_config_field_names_throws_ValueError() -> None: + with raises(ValueError): + _utils.get_structured_config_field_names("invalid") @mark.parametrize( From a4af7f5466cb74cbc2cde39e6f6e447a1496a992 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Fri, 26 Feb 2021 06:37:48 -0600 Subject: [PATCH 35/85] switch order of cases in _instantiate_structured_config_impl --- omegaconf/basecontainer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 15e1fb78a..3e2f91735 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -812,7 +812,11 @@ def _instantiate_structured_config_impl( object_type = conf._metadata.object_type object_type_field_names = get_structured_config_field_names(object_type) - if issubclass(object_type, dict): + if not issubclass(object_type, dict): + # normal structured config + assert set(instance_data.keys()) <= set(object_type_field_names) + result = object_type(**instance_data) + else: # Extending dict as a subclass retdict_field_items = { @@ -823,8 +827,4 @@ def _instantiate_structured_config_impl( } result = object_type(**retdict_field_items) result.update(retdict_nonfield_items) - else: - # normal structured config - assert set(instance_data.keys()) <= set(object_type_field_names) - result = object_type(**instance_data) return result From efcc93bc39ec9ca5b16c49474c4c267d74c6ebd3 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Fri, 26 Feb 2021 06:37:48 -0600 Subject: [PATCH 36/85] switch order of cases in _instantiate_structured_config_impl --- omegaconf/basecontainer.py | 10 +++++----- tests/test_to_container.py | 7 +------ 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 15e1fb78a..3e2f91735 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -812,7 +812,11 @@ def _instantiate_structured_config_impl( object_type = conf._metadata.object_type object_type_field_names = get_structured_config_field_names(object_type) - if issubclass(object_type, dict): + if not issubclass(object_type, dict): + # normal structured config + assert set(instance_data.keys()) <= set(object_type_field_names) + result = object_type(**instance_data) + else: # Extending dict as a subclass retdict_field_items = { @@ -823,8 +827,4 @@ def _instantiate_structured_config_impl( } result = object_type(**retdict_field_items) result.update(retdict_nonfield_items) - else: - # normal structured config - assert set(instance_data.keys()) <= set(object_type_field_names) - result = object_type(**instance_data) return result diff --git a/tests/test_to_container.py b/tests/test_to_container.py index 3ce84950a..3b2c3f42e 100644 --- a/tests/test_to_container.py +++ b/tests/test_to_container.py @@ -96,12 +96,7 @@ def test_scmode_dict_config( assert ret == ex_dict_config def test_scmode_instantiate( - self, - cfg: Any, - ex_dict: Any, - ex_dict_config: Any, - ex_instantiate: Any, - key: Any, + self, cfg: Any, ex_dict: Any, ex_dict_config: Any, ex_instantiate: Any, key: Any ) -> None: ret = OmegaConf.to_container(cfg, structured_config_mode=SCMode.INSTANTIATE) assert ret == ex_instantiate From 5b5186117426c993cc9c1f934b6500eb121865fc Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Fri, 26 Feb 2021 06:57:03 -0600 Subject: [PATCH 37/85] regroup tests for extracting structured config info --- tests/test_utils.py | 60 ++++++++++++++++++++++----------------------- 1 file changed, 29 insertions(+), 31 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 7649d7a5b..341b29bb2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -170,38 +170,36 @@ def test_valid_value_annotation_type(type_: type, expected: bool) -> None: assert valid_value_annotation_type(type_) == expected -@mark.parametrize( - "test_cls_or_obj", - [_TestDataclass, _TestDataclass(), _TestAttrsClass, _TestAttrsClass()], -) -def test_get_structured_config_data(test_cls_or_obj: Any) -> None: - d = _utils.get_structured_config_data(test_cls_or_obj) - assert d["x"] == 10 - assert d["s"] == "foo" - assert d["b"] == bool(True) - assert d["f"] == 3.14 - assert d["e"] == _TestEnum.A - assert d["list1"] == [] - assert d["dict1"] == {} - - -def test_get_structured_config_data_throws_ValueError() -> None: - with raises(ValueError): - _utils.get_structured_config_data("invalid") - - -@mark.parametrize( - "test_cls_or_obj", - [_TestDataclass, _TestDataclass(), _TestAttrsClass, _TestAttrsClass()], -) -def test_get_structured_config_field_names(test_cls_or_obj: Any) -> None: - field_names = _utils.get_structured_config_field_names(test_cls_or_obj) - assert field_names == ["x", "s", "b", "f", "e", "list1", "dict1"] - +class TestGetStructuredConfigInfo: + @mark.parametrize( + "test_cls_or_obj", + [_TestDataclass, _TestDataclass(), _TestAttrsClass, _TestAttrsClass()], + ) + def test_get_structured_config_data(self, test_cls_or_obj: Any) -> None: + d = _utils.get_structured_config_data(test_cls_or_obj) + assert d["x"] == 10 + assert d["s"] == "foo" + assert d["b"] == bool(True) + assert d["f"] == 3.14 + assert d["e"] == _TestEnum.A + assert d["list1"] == [] + assert d["dict1"] == {} + + def test_get_structured_config_data_throws_ValueError(self) -> None: + with raises(ValueError): + _utils.get_structured_config_data("invalid") + + @mark.parametrize( + "test_cls_or_obj", + [_TestDataclass, _TestDataclass(), _TestAttrsClass, _TestAttrsClass()], + ) + def test_get_structured_config_field_names(self, test_cls_or_obj: Any) -> None: + field_names = _utils.get_structured_config_field_names(test_cls_or_obj) + assert field_names == ["x", "s", "b", "f", "e", "list1", "dict1"] -def test_get_structured_config_field_names_throws_ValueError() -> None: - with raises(ValueError): - _utils.get_structured_config_field_names("invalid") + def test_get_structured_config_field_names_throws_ValueError(self) -> None: + with raises(ValueError): + _utils.get_structured_config_field_names("invalid") @mark.parametrize( From 19de3b4e4fd513a97ec98d6c231b23d1eb6ecb06 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Thu, 4 Mar 2021 01:25:15 -0600 Subject: [PATCH 38/85] Undo a stylistic change to tests/structured_conf/test_structured_config.py This change would be best left for another PR. --- tests/structured_conf/test_structured_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index 8bbcd4bf9..b525c734e 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -997,8 +997,8 @@ def test_color2color(self, module: Any) -> None: def test_str2user(self, module: Any) -> None: cfg = OmegaConf.structured(module.DictSubclass.Str2User()) - cfg.bond = module.User(name="James Bond", age=7) + cfg.bond = module.User(name="James Bond", age=7) assert cfg.bond.name == "James Bond" assert cfg.bond.age == 7 From 84c00ac27d3525fd47aadd28b86841c39de39dce Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Thu, 4 Mar 2021 02:06:52 -0600 Subject: [PATCH 39/85] add failing tests for throw if MISSING --- tests/test_to_container.py | 49 ++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 26 deletions(-) diff --git a/tests/test_to_container.py b/tests/test_to_container.py index 3b2c3f42e..a472c59c7 100644 --- a/tests/test_to_container.py +++ b/tests/test_to_container.py @@ -233,11 +233,8 @@ def test_basic(self, module: Any) -> None: assert user.age == 7 def test_basic_with_missing(self, module: Any) -> None: - user = self.round_trip_to_object(module.User()) - assert isinstance(user, module.User) - assert type(user) is module.User - assert user.name == MISSING - assert user.age == MISSING + with raises(MissingMandatoryValue): + self.round_trip_to_object(module.User()) def test_nested(self, module: Any) -> None: data = self.round_trip_to_object({"user": module.User("Bond", 7)}) @@ -248,12 +245,8 @@ def test_nested(self, module: Any) -> None: assert user.age == 7 def test_nested_with_missing(self, module: Any) -> None: - data = self.round_trip_to_object({"user": module.User()}) - user = data["user"] - assert isinstance(user, module.User) - assert type(user) is module.User - assert user.name == MISSING - assert user.age == MISSING + with raises(MissingMandatoryValue): + self.round_trip_to_object({"user": module.User()}) def test_list(self, module: Any) -> None: lst = self.round_trip_to_object(module.UserList([module.User("Bond", 7)])) @@ -267,10 +260,8 @@ def test_list(self, module: Any) -> None: assert user.age == 7 def test_list_with_missing(self, module: Any) -> None: - lst = self.round_trip_to_object(module.UserList) - assert isinstance(lst, module.UserList) - assert type(lst) is module.UserList - assert lst.list == MISSING + with raises(MissingMandatoryValue): + self.round_trip_to_object(module.UserList) def test_dict(self, module: Any) -> None: user_dict = self.round_trip_to_object( @@ -286,21 +277,27 @@ def test_dict(self, module: Any) -> None: assert user.age == 7 def test_dict_with_missing(self, module: Any) -> None: - user_dict = self.round_trip_to_object(module.UserDict) - assert isinstance(user_dict, module.UserDict) - assert type(user_dict) is module.UserDict - assert user_dict.dict == MISSING + with raises(MissingMandatoryValue): + user_dict = self.round_trip_to_object(module.UserDict) - def test_nested_object(self, module: Any) -> None: - nested = self.round_trip_to_object(module.NestedConfig) - assert isinstance(nested, module.NestedConfig) - assert type(nested) is module.NestedConfig + def test_nested_object_with_missing(self, module: Any) -> None: + with raises(MissingMandatoryValue): + self.round_trip_to_object(module.NestedConfig) - assert nested.default_value == MISSING + def test_nested_object(self, module: Any) -> None: + cfg = OmegaConf.structured(module.NestedConfig) + # fill in missing values: + cfg.default_value = module.NestedSubclass(mandatory_missing=123) + cfg.user_provided_default.mandatory_missing = 456 - assert isinstance(nested.user_provided_default, module.Nested) + nested = OmegaConf.to_object(cfg) + assert type(nested) is module.NestedConfig + assert type(nested.default_value) is module.NestedSubclass assert type(nested.user_provided_default) is module.Nested - assert nested.user_provided_default.with_default == 42 + + assert nested.default_value.mandatory_missing == 123 + assert nested.default_value.additional == 20 + assert nested.user_provided_default.mandatory_missing == 456 def test_to_object_resolve_is_True_by_default(self, module: Any) -> None: interp = self.round_trip_to_object(module.Interpolation) From 488a4b337fa0e120951ea18459187f619e1c6e18 Mon Sep 17 00:00:00 2001 From: Jasha10 <8935917+Jasha10@users.noreply.github.com> Date: Thu, 11 Mar 2021 05:00:23 -0600 Subject: [PATCH 40/85] Update omegaconf/_utils.py Co-authored-by: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> --- omegaconf/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/omegaconf/_utils.py b/omegaconf/_utils.py index 53b920ae8..74ec72d18 100644 --- a/omegaconf/_utils.py +++ b/omegaconf/_utils.py @@ -195,7 +195,7 @@ def _resolve_forward(type_: Type[Any], module: str) -> Type[Any]: def get_attr_class_field_names(obj: Any) -> List[str]: is_type = isinstance(obj, type) obj_type = obj if is_type else type(obj) - return [name for name in attr.fields_dict(obj_type).keys()] + return list(attr.fields_dict(obj_type)) def get_attr_data(obj: Any, allow_objects: Optional[bool] = None) -> Dict[str, Any]: From bbbb2458563499be3facbc7361c08b3ad5e7a85a Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Thu, 11 Mar 2021 19:13:07 -0600 Subject: [PATCH 41/85] fix mypy and flake8 issues --- tests/test_to_container.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_to_container.py b/tests/test_to_container.py index a472c59c7..93efb750c 100644 --- a/tests/test_to_container.py +++ b/tests/test_to_container.py @@ -5,7 +5,7 @@ from pytest import fixture, mark, param, raises, warns -from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf, SCMode +from omegaconf import DictConfig, ListConfig, MissingMandatoryValue, OmegaConf, SCMode from omegaconf.errors import InterpolationResolutionError from tests import Color, User @@ -278,7 +278,7 @@ def test_dict(self, module: Any) -> None: def test_dict_with_missing(self, module: Any) -> None: with raises(MissingMandatoryValue): - user_dict = self.round_trip_to_object(module.UserDict) + self.round_trip_to_object(module.UserDict) def test_nested_object_with_missing(self, module: Any) -> None: with raises(MissingMandatoryValue): @@ -290,7 +290,7 @@ def test_nested_object(self, module: Any) -> None: cfg.default_value = module.NestedSubclass(mandatory_missing=123) cfg.user_provided_default.mandatory_missing = 456 - nested = OmegaConf.to_object(cfg) + nested: Any = OmegaConf.to_object(cfg) assert type(nested) is module.NestedConfig assert type(nested.default_value) is module.NestedSubclass assert type(nested.user_provided_default) is module.Nested From 37e055facb9d13b19b5d235368b4d5753d43f781 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Thu, 11 Mar 2021 22:51:20 -0600 Subject: [PATCH 42/85] implement MissingMandatoryValue in case of MISSING param to dataclass instance --- omegaconf/basecontainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 222495289..73d000842 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -236,6 +236,8 @@ def convert(val: Node) -> Any: if structured_config_mode == SCMode.INSTANTIATE and is_structured_config( conf._metadata.object_type ): + if any(_is_missing_literal(value) for value in retdict.values()): + raise MissingMandatoryValue() retstruct = _instantiate_structured_config_impl( conf=conf, instance_data=retdict ) From 961b9fd220389a8d74ac34a2d3cda888797073ec Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Thu, 11 Mar 2021 22:52:07 -0600 Subject: [PATCH 43/85] update a test to reflect new behavior r.e. MISSING --- tests/test_to_container.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_to_container.py b/tests/test_to_container.py index 93efb750c..b1b1aa3c6 100644 --- a/tests/test_to_container.py +++ b/tests/test_to_container.py @@ -320,13 +320,16 @@ def test_to_object_InterpolationResolutionError(self, module: Any) -> None: self.round_trip_to_object(module.NestedWithAny) def test_nested_object_with_Any_ref_type(self, module: Any) -> None: - nested = self.round_trip_to_object(module.NestedWithAny, resolve=False) + cfg = OmegaConf.structured(module.NestedWithAny()) + cfg.var.mandatory_missing = 123 + nested = self.round_trip_to_object(cfg, resolve=False) assert isinstance(nested, module.NestedWithAny) assert type(nested) is module.NestedWithAny assert isinstance(nested.var, module.Nested) assert type(nested.var) is module.Nested assert nested.var.with_default == 10 + assert nested.var.mandatory_missing == 123 def test_str2user_instantiate(self, module: Any) -> None: cfg = OmegaConf.structured(module.DictSubclass.Str2User()) From 7f8addbb7ca846edd672060c6d64fae494d167f6 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Thu, 11 Mar 2021 22:57:12 -0600 Subject: [PATCH 44/85] use correct-typed value in test of KeyValidationError --- tests/structured_conf/test_structured_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index b525c734e..258a7bfe2 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -1028,7 +1028,7 @@ def test_str2user_with_field(self, module: Any) -> None: with raises(KeyValidationError): # bad key - cfg[Color.BLUE] = "nope" + cfg[Color.BLUE] = cfg.mp def test_str2str_with_field(self, module: Any) -> None: cfg = OmegaConf.structured(module.DictSubclass.Str2StrWithField()) From a10fa5bdca987120e33a26a7e57fcd588aa8f2e8 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Thu, 11 Mar 2021 23:00:12 -0600 Subject: [PATCH 45/85] modify to_object docstring --- omegaconf/omegaconf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index 761fcc538..8159dfe3e 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -635,7 +635,7 @@ def to_object( Any DictConfig objects backed by dataclasses or attrs classes are instantiated as instances of those backing classes. - This is an alias for OmegaConf.to_container(..., resolve=True, structured_config_mode=SCMode.INSTANTIATE) + This is an alias for OmegaConf.to_container(..., structured_config_mode=SCMode.INSTANTIATE) :param cfg: the config to convert :param resolve: True to resolve all values From 6b014aea0ad0185722942a1315e6af22b2a9bdfe Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Thu, 11 Mar 2021 23:03:08 -0600 Subject: [PATCH 46/85] use a set for _instantiate_structured_config_impl field names --- omegaconf/basecontainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 73d000842..0e129d643 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -790,7 +790,7 @@ def _instantiate_structured_config_impl( object_type = conf._metadata.object_type - object_type_field_names = get_structured_config_field_names(object_type) + object_type_field_names = set(get_structured_config_field_names(object_type)) if not issubclass(object_type, dict): # normal structured config assert set(instance_data.keys()) <= set(object_type_field_names) From 4cefc6d859546949f65917b98ee3f36166f2104c Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Thu, 11 Mar 2021 23:37:07 -0600 Subject: [PATCH 47/85] remove redundant call to set() --- omegaconf/basecontainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 0e129d643..bc77a8c77 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -793,7 +793,7 @@ def _instantiate_structured_config_impl( object_type_field_names = set(get_structured_config_field_names(object_type)) if not issubclass(object_type, dict): # normal structured config - assert set(instance_data.keys()) <= set(object_type_field_names) + assert set(instance_data.keys()) <= object_type_field_names result = object_type(**instance_data) else: # Extending dict as a subclass From b2a5ab2ce088ea24d1563a43829bfef2daf455db Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Thu, 11 Mar 2021 23:46:40 -0600 Subject: [PATCH 48/85] refactor TestInstantiateStructuredConfigs - parametrize the `module` fixture directly - reorder tests for increased consistency (test non-missing case before missing case) --- tests/test_to_container.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/tests/test_to_container.py b/tests/test_to_container.py index b1b1aa3c6..880232455 100644 --- a/tests/test_to_container.py +++ b/tests/test_to_container.py @@ -207,18 +207,15 @@ def test_to_container_missing_inter_no_resolve(src: Any, expected: Any) -> None: assert res == expected -@mark.parametrize( - "class_type", - [ - "tests.structured_conf.data.dataclasses", - "tests.structured_conf.data.attr_classes", - ], -) class TestInstantiateStructuredConfigs: - @fixture - def module(self, class_type: str) -> Any: - module: Any = import_module(class_type) - return module + @fixture( + params=[ + "tests.structured_conf.data.dataclasses", + "tests.structured_conf.data.attr_classes", + ] + ) + def module(self, request: Any) -> Any: + return import_module(request.param) def round_trip_to_object(self, input_data: Any, **kwargs: Any) -> Any: serialized = OmegaConf.create(input_data) @@ -280,10 +277,6 @@ def test_dict_with_missing(self, module: Any) -> None: with raises(MissingMandatoryValue): self.round_trip_to_object(module.UserDict) - def test_nested_object_with_missing(self, module: Any) -> None: - with raises(MissingMandatoryValue): - self.round_trip_to_object(module.NestedConfig) - def test_nested_object(self, module: Any) -> None: cfg = OmegaConf.structured(module.NestedConfig) # fill in missing values: @@ -299,6 +292,10 @@ def test_nested_object(self, module: Any) -> None: assert nested.default_value.additional == 20 assert nested.user_provided_default.mandatory_missing == 456 + def test_nested_object_with_missing(self, module: Any) -> None: + with raises(MissingMandatoryValue): + self.round_trip_to_object(module.NestedConfig) + def test_to_object_resolve_is_True_by_default(self, module: Any) -> None: interp = self.round_trip_to_object(module.Interpolation) assert isinstance(interp, module.Interpolation) From d28ae5d7a8e6fc00eb0ba1e818482fa1e1c5cd5b Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Thu, 11 Mar 2021 23:52:53 -0600 Subject: [PATCH 49/85] TestInstantiateStructuredConfigs: remove redundant isinstance assertions Testing `isinstance(a, A)` is redundant if we are testing `type(a) is A` on the next line. --- tests/test_to_container.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tests/test_to_container.py b/tests/test_to_container.py index 880232455..0b8024706 100644 --- a/tests/test_to_container.py +++ b/tests/test_to_container.py @@ -224,7 +224,6 @@ def round_trip_to_object(self, input_data: Any, **kwargs: Any) -> Any: def test_basic(self, module: Any) -> None: user = self.round_trip_to_object(module.User("Bond", 7)) - assert isinstance(user, module.User) assert type(user) is module.User assert user.name == "Bond" assert user.age == 7 @@ -236,7 +235,6 @@ def test_basic_with_missing(self, module: Any) -> None: def test_nested(self, module: Any) -> None: data = self.round_trip_to_object({"user": module.User("Bond", 7)}) user = data["user"] - assert isinstance(user, module.User) assert type(user) is module.User assert user.name == "Bond" assert user.age == 7 @@ -247,11 +245,9 @@ def test_nested_with_missing(self, module: Any) -> None: def test_list(self, module: Any) -> None: lst = self.round_trip_to_object(module.UserList([module.User("Bond", 7)])) - assert isinstance(lst, module.UserList) assert type(lst) is module.UserList assert len(lst.list) == 1 user = lst.list[0] - assert isinstance(user, module.User) assert type(user) is module.User assert user.name == "Bond" assert user.age == 7 @@ -264,11 +260,9 @@ def test_dict(self, module: Any) -> None: user_dict = self.round_trip_to_object( module.UserDict({"user007": module.User("Bond", 7)}) ) - assert isinstance(user_dict, module.UserDict) assert type(user_dict) is module.UserDict assert len(user_dict.dict) == 1 user = user_dict.dict["user007"] - assert isinstance(user, module.User) assert type(user) is module.User assert user.name == "Bond" assert user.age == 7 @@ -298,7 +292,6 @@ def test_nested_object_with_missing(self, module: Any) -> None: def test_to_object_resolve_is_True_by_default(self, module: Any) -> None: interp = self.round_trip_to_object(module.Interpolation) - assert isinstance(interp, module.Interpolation) assert type(interp) is module.Interpolation assert interp.z1 == 100 @@ -306,7 +299,6 @@ def test_to_object_resolve_is_True_by_default(self, module: Any) -> None: def test_to_object_resolve_False(self, module: Any) -> None: interp = self.round_trip_to_object(module.Interpolation, resolve=False) - assert isinstance(interp, module.Interpolation) assert type(interp) is module.Interpolation assert interp.z1 == "${x}" @@ -320,10 +312,8 @@ def test_nested_object_with_Any_ref_type(self, module: Any) -> None: cfg = OmegaConf.structured(module.NestedWithAny()) cfg.var.mandatory_missing = 123 nested = self.round_trip_to_object(cfg, resolve=False) - assert isinstance(nested, module.NestedWithAny) assert type(nested) is module.NestedWithAny - assert isinstance(nested.var, module.Nested) assert type(nested.var) is module.Nested assert nested.var.with_default == 10 assert nested.var.mandatory_missing == 123 @@ -333,7 +323,6 @@ def test_str2user_instantiate(self, module: Any) -> None: cfg.bond = module.User(name="James Bond", age=7) data = self.round_trip_to_object(cfg) - assert isinstance(data, module.DictSubclass.Str2User) assert type(data) is module.DictSubclass.Str2User assert type(data["bond"]) is module.User assert data["bond"] == module.User("James Bond", 7) @@ -343,7 +332,6 @@ def test_str2user_with_field_instantiate(self, module: Any) -> None: cfg.mp = module.User(name="Moneypenny", age=11) data = self.round_trip_to_object(cfg) - assert isinstance(data, module.DictSubclass.Str2UserWithField) assert type(data) is module.DictSubclass.Str2UserWithField assert type(data.foo) is module.User assert data.foo == module.User("Bond", 7) @@ -355,7 +343,6 @@ def test_str2str_with_field_instantiate(self, module: Any) -> None: cfg.hello = "world" data = self.round_trip_to_object(cfg) - assert isinstance(data, module.DictSubclass.Str2StrWithField) assert type(data) is module.DictSubclass.Str2StrWithField assert data.foo == "bar" assert data["hello"] == "world" From 62f34cf190f4ae6cb439a05cabcf7b1cbac791a3 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Fri, 12 Mar 2021 00:22:22 -0600 Subject: [PATCH 50/85] Use setattr(instance, k, v) when structured config has extra fields --- omegaconf/basecontainer.py | 20 ++++++++++---------- tests/test_to_container.py | 18 +++++++++++++++++- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index bc77a8c77..4bdadfa2e 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -791,19 +791,19 @@ def _instantiate_structured_config_impl( object_type = conf._metadata.object_type object_type_field_names = set(get_structured_config_field_names(object_type)) + + retdict_field_items = { + k: v for k, v in instance_data.items() if k in object_type_field_names + } + retdict_nonfield_items = { + k: v for k, v in instance_data.items() if k not in object_type_field_names + } + result = object_type(**retdict_field_items) if not issubclass(object_type, dict): # normal structured config - assert set(instance_data.keys()) <= object_type_field_names - result = object_type(**instance_data) + for k, v in retdict_nonfield_items.items(): + setattr(result, k, v) else: # Extending dict as a subclass - - retdict_field_items = { - k: v for k, v in instance_data.items() if k in object_type_field_names - } - retdict_nonfield_items = { - k: v for k, v in instance_data.items() if k not in object_type_field_names - } - result = object_type(**retdict_field_items) result.update(retdict_nonfield_items) return result diff --git a/tests/test_to_container.py b/tests/test_to_container.py index 0b8024706..188ec3ea3 100644 --- a/tests/test_to_container.py +++ b/tests/test_to_container.py @@ -5,7 +5,14 @@ from pytest import fixture, mark, param, raises, warns -from omegaconf import DictConfig, ListConfig, MissingMandatoryValue, OmegaConf, SCMode +from omegaconf import ( + DictConfig, + ListConfig, + MissingMandatoryValue, + OmegaConf, + SCMode, + open_dict, +) from omegaconf.errors import InterpolationResolutionError from tests import Color, User @@ -347,6 +354,15 @@ def test_str2str_with_field_instantiate(self, module: Any) -> None: assert data.foo == "bar" assert data["hello"] == "world" + def test_setattr_for_user_with_extra_field(self, module: Any) -> None: + cfg = OmegaConf.structured(module.User(name="James Bond", age=7)) + with open_dict(cfg): + cfg.extra_field = 123 + + user: Any = OmegaConf.to_object(cfg) + assert type(user) is module.User + assert user.extra_field == 123 + class TestEnumToStr: """Test the `enum_to_str` argument to the `OmegaConf.to_container function`""" From 249ac36936eb88ec87d97fa945a3145572758556 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Fri, 12 Mar 2021 01:03:20 -0600 Subject: [PATCH 51/85] add news fragment --- news/472.feature | 1 + 1 file changed, 1 insertion(+) create mode 100644 news/472.feature diff --git a/news/472.feature b/news/472.feature new file mode 100644 index 000000000..6b05290d1 --- /dev/null +++ b/news/472.feature @@ -0,0 +1 @@ +Add the OmegaConf.to_object method, which converts Structured Config objects back to native dataclasses or attrs classes. From 0869121c93c5a0974486263205f3f621791e20f9 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Fri, 12 Mar 2021 01:05:27 -0600 Subject: [PATCH 52/85] refactoring: rename variables --- omegaconf/basecontainer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 4bdadfa2e..30a3f216c 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -792,18 +792,18 @@ def _instantiate_structured_config_impl( object_type_field_names = set(get_structured_config_field_names(object_type)) - retdict_field_items = { + field_items = { k: v for k, v in instance_data.items() if k in object_type_field_names } - retdict_nonfield_items = { + nonfield_items = { k: v for k, v in instance_data.items() if k not in object_type_field_names } - result = object_type(**retdict_field_items) + result = object_type(**field_items) if not issubclass(object_type, dict): # normal structured config - for k, v in retdict_nonfield_items.items(): + for k, v in nonfield_items.items(): setattr(result, k, v) else: # Extending dict as a subclass - result.update(retdict_nonfield_items) + result.update(nonfield_items) return result From 3a471320f627a67de71a8297c41bf0f704dfd529 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Fri, 12 Mar 2021 19:16:45 -0600 Subject: [PATCH 53/85] Test error message for MissingMandatoryValue --- omegaconf/basecontainer.py | 24 ++++++++++++++++-------- tests/test_errors.py | 12 ++++++++++++ 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 30a3f216c..d229c171d 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -236,8 +236,6 @@ def convert(val: Node) -> Any: if structured_config_mode == SCMode.INSTANTIATE and is_structured_config( conf._metadata.object_type ): - if any(_is_missing_literal(value) for value in retdict.values()): - raise MissingMandatoryValue() retstruct = _instantiate_structured_config_impl( conf=conf, instance_data=retdict ) @@ -792,12 +790,22 @@ def _instantiate_structured_config_impl( object_type_field_names = set(get_structured_config_field_names(object_type)) - field_items = { - k: v for k, v in instance_data.items() if k in object_type_field_names - } - nonfield_items = { - k: v for k, v in instance_data.items() if k not in object_type_field_names - } + field_items: Dict[str, Any] = {} + nonfield_items: Dict[str, Any] = {} + for k, v in instance_data.items(): + if _is_missing_literal(v): + conf._format_and_raise( + key=k, + value=None, + cause=MissingMandatoryValue( + "Structured Config has Missing Mandatory Value: $KEY" + ), + ) + if k in object_type_field_names: + field_items[k] = v + else: + nonfield_items[k] = v + result = object_type(**field_items) if not issubclass(object_type, dict): # normal structured config diff --git a/tests/test_errors.py b/tests/test_errors.py index e5675c77e..f158d5c09 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -1211,6 +1211,18 @@ def finalize(self, cfg: Any) -> None: ), id="list,readonly:del", ), + # to_object + pytest.param( + Expected( + create=lambda: OmegaConf.structured(User), + op=lambda cfg: OmegaConf.to_object(cfg), + exception_type=MissingMandatoryValue, + msg="Structured Config has Missing Mandatory Value: name", + key="name", + child_node=lambda cfg: cfg._get_node("name"), + ), + id="to_object:structured-missing-field", + ), ] From 46aadbe87135f22e2982a654d5192d7792b3ce6a Mon Sep 17 00:00:00 2001 From: Jasha10 <8935917+Jasha10@users.noreply.github.com> Date: Fri, 12 Mar 2021 19:26:12 -0600 Subject: [PATCH 54/85] Formatting: delete whitespace --- omegaconf/basecontainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index d229c171d..289ccaabe 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -787,7 +787,6 @@ def _instantiate_structured_config_impl( from ._utils import get_structured_config_field_names object_type = conf._metadata.object_type - object_type_field_names = set(get_structured_config_field_names(object_type)) field_items: Dict[str, Any] = {} From c581548edb11709db46ff8982b25c12a50bca625 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Sat, 13 Mar 2021 10:15:22 -0600 Subject: [PATCH 55/85] include $OBJECT_TYPE in MissingMandatoryValue err msg --- omegaconf/basecontainer.py | 2 +- tests/test_errors.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 289ccaabe..c6b13c782 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -797,7 +797,7 @@ def _instantiate_structured_config_impl( key=k, value=None, cause=MissingMandatoryValue( - "Structured Config has Missing Mandatory Value: $KEY" + "Structured config of type `$OBJECT_TYPE` has missing mandatory value: $KEY" ), ) if k in object_type_field_names: diff --git a/tests/test_errors.py b/tests/test_errors.py index f158d5c09..314f1465e 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -1217,7 +1217,7 @@ def finalize(self, cfg: Any) -> None: create=lambda: OmegaConf.structured(User), op=lambda cfg: OmegaConf.to_object(cfg), exception_type=MissingMandatoryValue, - msg="Structured Config has Missing Mandatory Value: name", + msg="Structured config of type `User` has missing mandatory value: name", key="name", child_node=lambda cfg: cfg._get_node("name"), ), From 2cf460f9b2187369bce8576ea30d67ff7920ceca Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Mon, 15 Mar 2021 14:51:43 -0500 Subject: [PATCH 56/85] change _instantiate_structured_config_impl to an instance method --- omegaconf/basecontainer.py | 73 ++++++++++++++++++-------------------- 1 file changed, 35 insertions(+), 38 deletions(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index c6b13c782..f213755b6 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -236,8 +236,8 @@ def convert(val: Node) -> Any: if structured_config_mode == SCMode.INSTANTIATE and is_structured_config( conf._metadata.object_type ): - retstruct = _instantiate_structured_config_impl( - conf=conf, instance_data=retdict + retstruct = conf._instantiate_structured_config_impl( + instance_data=retdict ) return retstruct else: @@ -265,6 +265,39 @@ def convert(val: Node) -> Any: assert False + def _instantiate_structured_config_impl(conf, instance_data: Dict[str, Any]) -> Any: + """Instantiate an instance of `conf._metadata.object_type`, populated by `instance_data`.""" + from ._utils import get_structured_config_field_names + + object_type = conf._metadata.object_type + object_type_field_names = set(get_structured_config_field_names(object_type)) + + field_items: Dict[str, Any] = {} + nonfield_items: Dict[str, Any] = {} + for k, v in instance_data.items(): + if _is_missing_literal(v): + conf._format_and_raise( + key=k, + value=None, + cause=MissingMandatoryValue( + "Structured config of type `$OBJECT_TYPE` has missing mandatory value: $KEY" + ), + ) + if k in object_type_field_names: + field_items[k] = v + else: + nonfield_items[k] = v + + result = object_type(**field_items) + if not issubclass(object_type, dict): + # normal structured config + for k, v in nonfield_items.items(): + setattr(result, k, v) + else: + # Extending dict as a subclass + result.update(nonfield_items) + return result + def pretty(self, resolve: bool = False, sort_keys: bool = False) -> str: from omegaconf import OmegaConf @@ -778,39 +811,3 @@ def _update_types(node: Node, ref_type: type, object_type: Optional[type]) -> No if new_ref_type is not Any: node._metadata.ref_type = new_ref_type node._metadata.optional = new_is_optional - - -def _instantiate_structured_config_impl( - conf: "DictConfig", instance_data: Dict[str, Any] -) -> Any: - """Instantiate an instance of `conf._metadata.object_type`, populated by `instance_data`.""" - from ._utils import get_structured_config_field_names - - object_type = conf._metadata.object_type - object_type_field_names = set(get_structured_config_field_names(object_type)) - - field_items: Dict[str, Any] = {} - nonfield_items: Dict[str, Any] = {} - for k, v in instance_data.items(): - if _is_missing_literal(v): - conf._format_and_raise( - key=k, - value=None, - cause=MissingMandatoryValue( - "Structured config of type `$OBJECT_TYPE` has missing mandatory value: $KEY" - ), - ) - if k in object_type_field_names: - field_items[k] = v - else: - nonfield_items[k] = v - - result = object_type(**field_items) - if not issubclass(object_type, dict): - # normal structured config - for k, v in nonfield_items.items(): - setattr(result, k, v) - else: - # Extending dict as a subclass - result.update(nonfield_items) - return result From d6e9749cacf127bcd3196854b6c9945014537cbd Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Mon, 15 Mar 2021 14:59:19 -0500 Subject: [PATCH 57/85] simplify `retdict` & `retstruct` to `ret` --- omegaconf/basecontainer.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index f213755b6..90e90ccd1 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -213,7 +213,7 @@ def convert(val: Node) -> Any: ): return conf - retdict: Dict[str, Any] = {} + ret: Any = {} for key in conf.keys(): node = conf._get_node(key) assert isinstance(node, Node) @@ -224,24 +224,20 @@ def convert(val: Node) -> Any: if enum_to_str and isinstance(key, Enum): key = f"{key.name}" if isinstance(node, Container): - retdict[key] = BaseContainer._to_content( + ret[key] = BaseContainer._to_content( node, resolve=resolve, enum_to_str=enum_to_str, structured_config_mode=structured_config_mode, ) else: - retdict[key] = convert(node) + ret[key] = convert(node) if structured_config_mode == SCMode.INSTANTIATE and is_structured_config( conf._metadata.object_type ): - retstruct = conf._instantiate_structured_config_impl( - instance_data=retdict - ) - return retstruct - else: - return retdict + ret = conf._instantiate_structured_config_impl(instance_data=ret) + return ret elif isinstance(conf, ListConfig): retlist: List[Any] = [] for index in range(len(conf)): From f16ad229015723c6e211d6e285e6566f0c698679 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Mon, 15 Mar 2021 15:59:31 -0500 Subject: [PATCH 58/85] rename `conf` -> `self` in _instantiate_structured_config_impl --- omegaconf/basecontainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 90e90ccd1..54c26fb0a 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -261,18 +261,18 @@ def convert(val: Node) -> Any: assert False - def _instantiate_structured_config_impl(conf, instance_data: Dict[str, Any]) -> Any: - """Instantiate an instance of `conf._metadata.object_type`, populated by `instance_data`.""" + def _instantiate_structured_config_impl(self, instance_data: Dict[str, Any]) -> Any: + """Instantiate an instance of `self._metadata.object_type`, populated by `instance_data`.""" from ._utils import get_structured_config_field_names - object_type = conf._metadata.object_type + object_type = self._metadata.object_type object_type_field_names = set(get_structured_config_field_names(object_type)) field_items: Dict[str, Any] = {} nonfield_items: Dict[str, Any] = {} for k, v in instance_data.items(): if _is_missing_literal(v): - conf._format_and_raise( + self._format_and_raise( key=k, value=None, cause=MissingMandatoryValue( From 2bf73b024af397bf60baaf7f713cb3af0c3d4def Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Tue, 16 Mar 2021 12:29:49 -0500 Subject: [PATCH 59/85] remove `resolve` arg from `to_object` --- omegaconf/omegaconf.py | 4 +--- tests/test_to_container.py | 14 +++++++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index 8159dfe3e..b294f15f2 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -627,7 +627,6 @@ def to_container( def to_object( cfg: Any, *, - resolve: bool = True, enum_to_str: bool = False, ) -> Union[Dict[DictKeyType, Any], List[Any], None, str, Any]: """ @@ -638,13 +637,12 @@ def to_object( This is an alias for OmegaConf.to_container(..., structured_config_mode=SCMode.INSTANTIATE) :param cfg: the config to convert - :param resolve: True to resolve all values :param enum_to_str: True to convert Enum values to strings :return: A dict or a list or dataclass representing this config. """ return OmegaConf.to_container( cfg=cfg, - resolve=resolve, + resolve=True, enum_to_str=enum_to_str, structured_config_mode=SCMode.INSTANTIATE, ) diff --git a/tests/test_to_container.py b/tests/test_to_container.py index 188ec3ea3..439a9047e 100644 --- a/tests/test_to_container.py +++ b/tests/test_to_container.py @@ -304,8 +304,13 @@ def test_to_object_resolve_is_True_by_default(self, module: Any) -> None: assert interp.z1 == 100 assert interp.z2 == "100_200" - def test_to_object_resolve_False(self, module: Any) -> None: - interp = self.round_trip_to_object(module.Interpolation, resolve=False) + def test_to_container_INSTANTIATE_resolve_False(self, module: Any) -> None: + """Test the lower level `to_container` API with SCMode.INSTANTIATE and resolve=False""" + serialized = OmegaConf.structured(module.Interpolation) + interp = OmegaConf.to_container( + serialized, resolve=False, structured_config_mode=SCMode.INSTANTIATE + ) + assert isinstance(interp, module.Interpolation) assert type(interp) is module.Interpolation assert interp.z1 == "${x}" @@ -318,12 +323,15 @@ def test_to_object_InterpolationResolutionError(self, module: Any) -> None: def test_nested_object_with_Any_ref_type(self, module: Any) -> None: cfg = OmegaConf.structured(module.NestedWithAny()) cfg.var.mandatory_missing = 123 - nested = self.round_trip_to_object(cfg, resolve=False) + with open_dict(cfg): + cfg.value_at_root = 456 + nested = self.round_trip_to_object(cfg) assert type(nested) is module.NestedWithAny assert type(nested.var) is module.Nested assert nested.var.with_default == 10 assert nested.var.mandatory_missing == 123 + assert nested.var.interpolation == 456 def test_str2user_instantiate(self, module: Any) -> None: cfg = OmegaConf.structured(module.DictSubclass.Str2User()) From eb41a377e809a6948083cf5b60c2a1c8662d3002 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Tue, 16 Mar 2021 19:09:14 -0500 Subject: [PATCH 60/85] Docs example for SCMode.INSTANTIATE --- docs/source/usage.rst | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 5850c3c9c..a929484fb 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -757,6 +757,18 @@ as DictConfig, allowing attribute style access on the resulting node. >>> assert type(container["structured_config"]) is DictConfig >>> assert container["structured_config"].port == 80 +Using **structured_config_mode=SCMode.INSTANTIATE** causes such nodes to +be converted to instances of the backing dataclass. + +.. doctest:: + + >>> conf = OmegaConf.create({"structured_config": MyConfig}) + >>> container = OmegaConf.to_container(conf, structured_config_mode=SCMode.INSTANTIATE) + >>> print(container) + {'structured_config': MyConfig(port=80, host='localhost')} + >>> assert type(container["structured_config"]) is MyConfig + >>> assert container["structured_config"].port == 80 + OmegaConf.select ^^^^^^^^^^^^^^^^ OmegaConf.select() allow you to select a config node or value using a dot-notation key. From 30550bfc74b2e1fc604921946626b36b0317c209 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Tue, 16 Mar 2021 19:14:04 -0500 Subject: [PATCH 61/85] docs: OmegaConf.to_object example --- docs/source/usage.rst | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index a929484fb..bbee19b8f 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -762,13 +762,24 @@ be converted to instances of the backing dataclass. .. doctest:: - >>> conf = OmegaConf.create({"structured_config": MyConfig}) >>> container = OmegaConf.to_container(conf, structured_config_mode=SCMode.INSTANTIATE) >>> print(container) {'structured_config': MyConfig(port=80, host='localhost')} >>> assert type(container["structured_config"]) is MyConfig >>> assert container["structured_config"].port == 80 +The `OmegaConf.to_object` method provides a convenient alias to achieve the above: + +.. doctest:: + + >>> container = OmegaConf.to_object(conf) + >>> print(container) + {'structured_config': MyConfig(port=80, host='localhost')} + +Calling `OmegaConf.to_object(conf)` is equivalent to +`OmegaConf.to_container(conf, resolve=True, structured_config_mode=SCMode.INSTANTIATE)`; +string interpolations will be resolved before dataclass instances are instantiated. + OmegaConf.select ^^^^^^^^^^^^^^^^ OmegaConf.select() allow you to select a config node or value using a dot-notation key. From 0611c93b05311ba6d056ac27e2fdfb8706dc67e1 Mon Sep 17 00:00:00 2001 From: Jasha10 <8935917+Jasha10@users.noreply.github.com> Date: Tue, 16 Mar 2021 19:16:05 -0500 Subject: [PATCH 62/85] Docs minor edit --- docs/source/usage.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index bbee19b8f..9c305e9e2 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -778,7 +778,7 @@ The `OmegaConf.to_object` method provides a convenient alias to achieve the abov Calling `OmegaConf.to_object(conf)` is equivalent to `OmegaConf.to_container(conf, resolve=True, structured_config_mode=SCMode.INSTANTIATE)`; -string interpolations will be resolved before dataclass instances are instantiated. +string interpolations will be resolved before dataclass instances are created. OmegaConf.select ^^^^^^^^^^^^^^^^ From 32f6c68c6dde51365bc15d94fa99e54388bbfa41 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 17 Mar 2021 00:30:17 -0500 Subject: [PATCH 63/85] updates to to_object docs --- docs/source/usage.rst | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 9c305e9e2..eea3d74b4 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -757,28 +757,31 @@ as DictConfig, allowing attribute style access on the resulting node. >>> assert type(container["structured_config"]) is DictConfig >>> assert container["structured_config"].port == 80 -Using **structured_config_mode=SCMode.INSTANTIATE** causes such nodes to -be converted to instances of the backing dataclass. +OmegaConf.to_object +^^^^^^^^^^^^^^^^^^^^^^ +The ``OmegaConf.to_object`` method is very similar to the +``OmegaConf.to_container`` method. .. doctest:: - >>> container = OmegaConf.to_container(conf, structured_config_mode=SCMode.INSTANTIATE) + >>> container = OmegaConf.to_object(conf) >>> print(container) {'structured_config': MyConfig(port=80, host='localhost')} >>> assert type(container["structured_config"]) is MyConfig >>> assert container["structured_config"].port == 80 -The `OmegaConf.to_object` method provides a convenient alias to achieve the above: - -.. doctest:: - - >>> container = OmegaConf.to_object(conf) - >>> print(container) - {'structured_config': MyConfig(port=80, host='localhost')} - -Calling `OmegaConf.to_object(conf)` is equivalent to -`OmegaConf.to_container(conf, resolve=True, structured_config_mode=SCMode.INSTANTIATE)`; -string interpolations will be resolved before dataclass instances are created. +Note that here, ``container.structured_config`` is actually an instance of +``MyConfig``, whereas in the previous examples we had a ``dict`` or +``DictConfig`` object that was duck-typed to look like an instance of +``MyConfig``. + +The call ``OmegaConf.to_object(conf)`` is equivalent to +``OmegaConf.to_container(conf, resolve=True, +structured_config_mode=SCMode.INSTANTIATE)``. The ``resolve=True`` keyword +argument means that string interpolations are resolved before conversion to a +container, and ``structured_config_mode=SCMode.INSTANTIATE`` means that each +structured config node is converted into an actual instance of the underlying +structured config type. OmegaConf.select ^^^^^^^^^^^^^^^^ From 1019df6dac36e827789a6a3040557c10f5d3cf1d Mon Sep 17 00:00:00 2001 From: Jasha10 <8935917+Jasha10@users.noreply.github.com> Date: Thu, 18 Mar 2021 11:40:11 -0500 Subject: [PATCH 64/85] Revert test_structured_config.py (remove redundant test) --- .../structured_conf/test_structured_config.py | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index 258a7bfe2..f4cda14f6 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -1010,26 +1010,6 @@ def test_str2user(self, module: Any) -> None: # bad key cfg[Color.BLUE] = "nope" - def test_str2user_with_field(self, module: Any) -> None: - cfg = OmegaConf.structured(module.DictSubclass.Str2UserWithField()) - - assert cfg.foo.name == "Bond" - assert cfg.foo.age == 7 - assert isinstance(cfg.foo, DictConfig) - - cfg.mp = module.User(name="Moneypenny", age=11) - assert cfg.mp.name == "Moneypenny" - assert cfg.mp.age == 11 - assert isinstance(cfg.mp, DictConfig) - - with raises(ValidationError): - # bad value - cfg.hello = "world" - - with raises(KeyValidationError): - # bad key - cfg[Color.BLUE] = cfg.mp - def test_str2str_with_field(self, module: Any) -> None: cfg = OmegaConf.structured(module.DictSubclass.Str2StrWithField()) assert cfg.foo == "bar" From 3fef7f0563bb1b304bf625e0a54a5dd42b3f9cfd Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Thu, 18 Mar 2021 16:56:00 -0500 Subject: [PATCH 65/85] dict subclass: DictConfig items become instance attributes --- omegaconf/basecontainer.py | 9 ++------- tests/test_to_container.py | 10 +++++----- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 54c26fb0a..da69acf75 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -285,13 +285,8 @@ def _instantiate_structured_config_impl(self, instance_data: Dict[str, Any]) -> nonfield_items[k] = v result = object_type(**field_items) - if not issubclass(object_type, dict): - # normal structured config - for k, v in nonfield_items.items(): - setattr(result, k, v) - else: - # Extending dict as a subclass - result.update(nonfield_items) + for k, v in nonfield_items.items(): + setattr(result, k, v) return result def pretty(self, resolve: bool = False, sort_keys: bool = False) -> str: diff --git a/tests/test_to_container.py b/tests/test_to_container.py index 439a9047e..f220dca07 100644 --- a/tests/test_to_container.py +++ b/tests/test_to_container.py @@ -339,8 +339,8 @@ def test_str2user_instantiate(self, module: Any) -> None: data = self.round_trip_to_object(cfg) assert type(data) is module.DictSubclass.Str2User - assert type(data["bond"]) is module.User - assert data["bond"] == module.User("James Bond", 7) + assert type(data.bond) is module.User + assert data.bond == module.User("James Bond", 7) def test_str2user_with_field_instantiate(self, module: Any) -> None: cfg = OmegaConf.structured(module.DictSubclass.Str2UserWithField()) @@ -350,8 +350,8 @@ def test_str2user_with_field_instantiate(self, module: Any) -> None: assert type(data) is module.DictSubclass.Str2UserWithField assert type(data.foo) is module.User assert data.foo == module.User("Bond", 7) - assert type(data["mp"]) is module.User - assert data["mp"] == module.User("Moneypenny", 11) + assert type(data.mp) is module.User + assert data.mp == module.User("Moneypenny", 11) def test_str2str_with_field_instantiate(self, module: Any) -> None: cfg = OmegaConf.structured(module.DictSubclass.Str2StrWithField()) @@ -360,7 +360,7 @@ def test_str2str_with_field_instantiate(self, module: Any) -> None: assert type(data) is module.DictSubclass.Str2StrWithField assert data.foo == "bar" - assert data["hello"] == "world" + assert data.hello == "world" def test_setattr_for_user_with_extra_field(self, module: Any) -> None: cfg = OmegaConf.structured(module.User(name="James Bond", age=7)) From 15a03eab50cade53b7469cfbc4ccacb6c6c8fd2c Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Thu, 18 Mar 2021 20:01:30 -0500 Subject: [PATCH 66/85] docs: use `show` instead of `print`/`assert` --- docs/source/usage.rst | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index b0715f165..68505194b 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -814,10 +814,10 @@ The ``OmegaConf.to_object`` method is very similar to the .. doctest:: >>> container = OmegaConf.to_object(conf) - >>> print(container) - {'structured_config': MyConfig(port=80, host='localhost')} - >>> assert type(container["structured_config"]) is MyConfig - >>> assert container["structured_config"].port == 80 + >>> show(container) + type: dict, value: {'structured_config': MyConfig(port=80, host='localhost')} + >>> show(container["structured_config"]) + type: MyConfig, value: MyConfig(port=80, host='localhost') Note that here, ``container.structured_config`` is actually an instance of ``MyConfig``, whereas in the previous examples we had a ``dict`` or From f3171f243d5e375563d78c216b3d103948e7c411 Mon Sep 17 00:00:00 2001 From: Jasha10 <8935917+Jasha10@users.noreply.github.com> Date: Fri, 19 Mar 2021 16:06:08 -0500 Subject: [PATCH 67/85] minor doc fix Co-authored-by: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> --- docs/source/usage.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 68505194b..0be886368 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -819,7 +819,7 @@ The ``OmegaConf.to_object`` method is very similar to the >>> show(container["structured_config"]) type: MyConfig, value: MyConfig(port=80, host='localhost') -Note that here, ``container.structured_config`` is actually an instance of +Note that here, ``container["structured_config"]`` is actually an instance of ``MyConfig``, whereas in the previous examples we had a ``dict`` or ``DictConfig`` object that was duck-typed to look like an instance of ``MyConfig``. From 80284f4fc3209845fc380e40203973af0d3e0beb Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Mon, 29 Mar 2021 13:06:43 -0500 Subject: [PATCH 68/85] docs: Improve introduction to `to_object` method --- docs/source/usage.rst | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 0be886368..a1a75b08c 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -808,8 +808,10 @@ as DictConfig, allowing attribute style access on the resulting node. OmegaConf.to_object ^^^^^^^^^^^^^^^^^^^^^^ -The ``OmegaConf.to_object`` method is very similar to the -``OmegaConf.to_container`` method. +The ``OmegaConf.to_object`` method recursively converts DictConfig and ListConfig objects +into dicts and lists, with the execption that Structured Config objects are +converted into instances of the backing dataclass or attr class. All OmegaConf +interpolations are resolved before conversion to python containers. .. doctest:: From 8f17b9a45739d748ee3028b7f198e0b4317786c5 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Mon, 29 Mar 2021 13:09:02 -0500 Subject: [PATCH 69/85] docs: Remove explanation r.e. equivalent OmegaConf.to_container calls --- docs/source/usage.rst | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index a1a75b08c..59f5f1c99 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -828,11 +828,7 @@ Note that here, ``container["structured_config"]`` is actually an instance of The call ``OmegaConf.to_object(conf)`` is equivalent to ``OmegaConf.to_container(conf, resolve=True, -structured_config_mode=SCMode.INSTANTIATE)``. The ``resolve=True`` keyword -argument means that string interpolations are resolved before conversion to a -container, and ``structured_config_mode=SCMode.INSTANTIATE`` means that each -structured config node is converted into an actual instance of the underlying -structured config type. +structured_config_mode=SCMode.INSTANTIATE)``. OmegaConf.select ^^^^^^^^^^^^^^^^ From e1e034aa7147b54610d4d093dfd20d5f8d75e1ee Mon Sep 17 00:00:00 2001 From: Jasha10 <8935917+Jasha10@users.noreply.github.com> Date: Mon, 29 Mar 2021 13:12:46 -0500 Subject: [PATCH 70/85] docs: clarification on ducktyping Co-authored-by: Omry Yadan --- docs/source/usage.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 59f5f1c99..b99804990 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -822,7 +822,7 @@ interpolations are resolved before conversion to python containers. type: MyConfig, value: MyConfig(port=80, host='localhost') Note that here, ``container["structured_config"]`` is actually an instance of -``MyConfig``, whereas in the previous examples we had a ``dict`` or +``MyConfig``, whereas in the previous examples we had a ``dict`` or a ``DictConfig`` object that was duck-typed to look like an instance of ``MyConfig``. From 29323a4807789d35bf089b0cf06f247cf7fde541 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Mon, 29 Mar 2021 13:46:54 -0500 Subject: [PATCH 71/85] to_container docs: explicitly document the new SCMode.INSTANTIATE member --- docs/source/usage.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 75fd97df8..b6fe614fd 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -776,6 +776,11 @@ Structured Config nodes using the ``structured_config_mode`` option. By default, Structured Config nodes are converted to plain dict. Using ``structured_config_mode=SCMode.DICT_CONFIG`` causes such nodes to remain as DictConfig, allowing attribute style access on the resulting node. +Using ``structured_config_mode=SCMode.INSTANTIATE``, Structured Config nodes +are converted to instances of the backing dataclass or attrs class. Note that +typically ``structured_config_mode=SCMode.INSTANTIATE`` makes the most sense +when combined with ``resolve=True``, so that interpolations are resolved before +being using to instantiate dataclass/attr class instances. .. doctest:: From fe5df1d6202ff18f92d3352be3c8d6e7bcd8ff80 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Mon, 29 Mar 2021 14:02:29 -0500 Subject: [PATCH 72/85] update `to_object` docstring --- omegaconf/omegaconf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index 6f060b15b..d49195a5c 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -608,7 +608,7 @@ def to_object( Any DictConfig objects backed by dataclasses or attrs classes are instantiated as instances of those backing classes. - This is an alias for OmegaConf.to_container(..., structured_config_mode=SCMode.INSTANTIATE) + This is an alias for OmegaConf.to_container(..., resolve=True, structured_config_mode=SCMode.INSTANTIATE) :param cfg: the config to convert :param enum_to_str: True to convert Enum values to strings From a9a05ee47eca3c35c4db98f6246914522908c644 Mon Sep 17 00:00:00 2001 From: Jasha10 <8935917+Jasha10@users.noreply.github.com> Date: Mon, 29 Mar 2021 14:51:25 -0500 Subject: [PATCH 73/85] docs: fix typos Co-authored-by: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> --- docs/source/usage.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index b6fe614fd..71b09d5b1 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -780,7 +780,7 @@ Using ``structured_config_mode=SCMode.INSTANTIATE``, Structured Config nodes are converted to instances of the backing dataclass or attrs class. Note that typically ``structured_config_mode=SCMode.INSTANTIATE`` makes the most sense when combined with ``resolve=True``, so that interpolations are resolved before -being using to instantiate dataclass/attr class instances. +being used to instantiate dataclass/attr class instances. .. doctest:: @@ -796,9 +796,9 @@ being using to instantiate dataclass/attr class instances. OmegaConf.to_object ^^^^^^^^^^^^^^^^^^^^^^ The ``OmegaConf.to_object`` method recursively converts DictConfig and ListConfig objects -into dicts and lists, with the execption that Structured Config objects are +into dicts and lists, with the exception that Structured Config objects are converted into instances of the backing dataclass or attr class. All OmegaConf -interpolations are resolved before conversion to python containers. +interpolations are resolved before conversion to Python containers. .. doctest:: From db098803c5fa1cea3f5a973e022f5d2b07a1f6df Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 31 Mar 2021 13:49:39 -0500 Subject: [PATCH 74/85] empty commit (to trigger CI workflow) From a17a11fcd1ebbfb5a5d42c47deea0e060dbd18a3 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 31 Mar 2021 21:39:18 -0500 Subject: [PATCH 75/85] refactor test_SCMode --- tests/test_to_container.py | 102 ++++++++++++++++++++++++------------- 1 file changed, 66 insertions(+), 36 deletions(-) diff --git a/tests/test_to_container.py b/tests/test_to_container.py index 1d97b0672..86a8d581d 100644 --- a/tests/test_to_container.py +++ b/tests/test_to_container.py @@ -1,7 +1,7 @@ import re from enum import Enum from importlib import import_module -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from pytest import fixture, mark, param, raises @@ -48,58 +48,88 @@ def assert_container_with_primitives(item: Any) -> None: @mark.parametrize( - "src,ex_dict,ex_dict_config,ex_instantiate,key", + "structured_config_mode,src,expected,key,expected_value_type", [ param( + SCMode.DICT, {"user": User(age=7, name="Bond")}, {"user": {"name": "Bond", "age": 7}}, + "user", + dict, + id="DICT-dict", + ), + param( + SCMode.DICT, + [1, User(age=7, name="Bond")], + [1, {"name": "Bond", "age": 7}], + 1, + dict, + id="DICT-list", + ), + param( + SCMode.DICT_CONFIG, {"user": User(age=7, name="Bond")}, {"user": User(age=7, name="Bond")}, "user", - id="structured-inside-dict", + DictConfig, + id="DICT_CONFIG-dict", ), param( + SCMode.DICT_CONFIG, [1, User(age=7, name="Bond")], - [1, {"name": "Bond", "age": 7}], + [1, User(age=7, name="Bond")], + 1, + DictConfig, + id="DICT_CONFIG-list", + ), + param( + SCMode.INSTANTIATE, + {"user": User(age=7, name="Bond")}, + {"user": User(age=7, name="Bond")}, + "user", + User, + id="INSTANTIATE-dict", + ), + param( + SCMode.INSTANTIATE, [1, User(age=7, name="Bond")], [1, User(age=7, name="Bond")], 1, - id="structured-inside-list", + User, + id="INSTANTIATE-list", + ), + param( + None, + {"user": User(age=7, name="Bond")}, + {"user": {"name": "Bond", "age": 7}}, + "user", + dict, + id="default-dict", + ), + param( + None, + [1, User(age=7, name="Bond")], + [1, {"name": "Bond", "age": 7}], + 1, + dict, + id="default-list", ), ], ) -class TestSCMode: - @fixture - def cfg(self, src: Any) -> Any: - return OmegaConf.create(src) - - def test_exclude_structured_configs_default( - self, cfg: Any, ex_dict: Any, ex_dict_config: Any, ex_instantiate: Any, key: Any - ) -> None: +def test_SCMode( + src: Any, + structured_config_mode: Optional[SCMode], + expected: Any, + expected_value_type: Any, + key: Any, +) -> None: + cfg = OmegaConf.create(src) + if structured_config_mode is None: ret = OmegaConf.to_container(cfg) - assert ret == ex_dict - assert isinstance(ret[key], dict) - - def test_scmode_dict( - self, cfg: Any, ex_dict: Any, ex_dict_config: Any, ex_instantiate: Any, key: Any - ) -> None: - ret = OmegaConf.to_container(cfg, structured_config_mode=SCMode.DICT) - assert ret == ex_dict - assert isinstance(ret[key], dict) - - def test_scmode_dict_config( - self, cfg: Any, ex_dict: Any, ex_dict_config: Any, ex_instantiate: Any, key: Any - ) -> None: - ret = OmegaConf.to_container(cfg, structured_config_mode=SCMode.DICT_CONFIG) - assert ret == ex_dict_config - assert isinstance(ret[key], DictConfig) - - def test_scmode_instantiate( - self, cfg: Any, ex_dict: Any, ex_dict_config: Any, ex_instantiate: Any, key: Any - ) -> None: - ret = OmegaConf.to_container(cfg, structured_config_mode=SCMode.INSTANTIATE) - assert ret == ex_instantiate - assert isinstance(ret[key], User) + else: + ret = OmegaConf.to_container(cfg, structured_config_mode=structured_config_mode) + assert ret == expected + assert isinstance(ret[key], expected_value_type) @mark.parametrize( From 29a526b032a862295e6b1da4ca178c848f8d2872 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Thu, 1 Apr 2021 13:54:33 -0500 Subject: [PATCH 76/85] lowercase test fn name (test_SCMode -> test_scmode) --- tests/test_to_container.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_to_container.py b/tests/test_to_container.py index 86a8d581d..2fabb6baf 100644 --- a/tests/test_to_container.py +++ b/tests/test_to_container.py @@ -116,7 +116,7 @@ def assert_container_with_primitives(item: Any) -> None: ), ], ) -def test_SCMode( +def test_scmode( src: Any, structured_config_mode: Optional[SCMode], expected: Any, From c672c100f9bfd07f1e753d448118603abd71eea2 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Fri, 2 Apr 2021 16:02:57 -0500 Subject: [PATCH 77/85] StructuredConfigs have resolve=True and enum_to_str=False --- docs/source/usage.rst | 12 ++++-- omegaconf/basecontainer.py | 44 ++++--------------- omegaconf/dictconfig.py | 49 +++++++++++++++++++++- tests/structured_conf/data/attr_classes.py | 8 ++++ tests/structured_conf/data/dataclasses.py | 8 ++++ tests/test_to_container.py | 37 ++++++++++++---- 6 files changed, 110 insertions(+), 48 deletions(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 71b09d5b1..063d90a11 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -778,9 +778,15 @@ Using ``structured_config_mode=SCMode.DICT_CONFIG`` causes such nodes to remain as DictConfig, allowing attribute style access on the resulting node. Using ``structured_config_mode=SCMode.INSTANTIATE``, Structured Config nodes are converted to instances of the backing dataclass or attrs class. Note that -typically ``structured_config_mode=SCMode.INSTANTIATE`` makes the most sense -when combined with ``resolve=True``, so that interpolations are resolved before -being used to instantiate dataclass/attr class instances. +when ``structured_config_mode=SCMode.INSTANTIATE``, interpolations nested within +a structured config node will be resolved, even if ``OmegaConf.to_container`` is called +with the the keyword argument ``resolve=False``, so that interpolations are resolved before +being used to instantiate dataclass/attr class instances. Interpolations within +non-structured parent nodes will be resolved (or not) as usual, according to +the ``resolve`` keyword arg. +Similarly, when ``structured_config_mode=SCMode.INSTANTIATE``, enum values nested within a +structured config node will not be converted to ``str``, even if ``OmegaConf.to_container`` +is called with ``enum_to_str=True``. .. doctest:: diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 0634af88f..d5372af64 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -213,8 +213,12 @@ def convert(val: Node) -> Any: and structured_config_mode == SCMode.DICT_CONFIG ): return conf + if structured_config_mode == SCMode.INSTANTIATE and is_structured_config( + conf._metadata.object_type + ): + return conf._to_object() - ret: Any = {} + retdict: Dict[str, Any] = {} for key in conf.keys(): node = conf._get_node(key) assert isinstance(node, Node) @@ -225,20 +229,16 @@ def convert(val: Node) -> Any: if enum_to_str and isinstance(key, Enum): key = f"{key.name}" if isinstance(node, Container): - ret[key] = BaseContainer._to_content( + retdict[key] = BaseContainer._to_content( node, resolve=resolve, enum_to_str=enum_to_str, structured_config_mode=structured_config_mode, ) else: - ret[key] = convert(node) + retdict[key] = convert(node) - if structured_config_mode == SCMode.INSTANTIATE and is_structured_config( - conf._metadata.object_type - ): - ret = conf._instantiate_structured_config_impl(instance_data=ret) - return ret + return retdict elif isinstance(conf, ListConfig): retlist: List[Any] = [] for index in range(len(conf)): @@ -262,34 +262,6 @@ def convert(val: Node) -> Any: assert False - def _instantiate_structured_config_impl(self, instance_data: Dict[str, Any]) -> Any: - """Instantiate an instance of `self._metadata.object_type`, populated by `instance_data`.""" - from ._utils import get_structured_config_field_names - - object_type = self._metadata.object_type - object_type_field_names = set(get_structured_config_field_names(object_type)) - - field_items: Dict[str, Any] = {} - nonfield_items: Dict[str, Any] = {} - for k, v in instance_data.items(): - if _is_missing_literal(v): - self._format_and_raise( - key=k, - value=None, - cause=MissingMandatoryValue( - "Structured config of type `$OBJECT_TYPE` has missing mandatory value: $KEY" - ), - ) - if k in object_type_field_names: - field_items[k] = v - else: - nonfield_items[k] = v - - result = object_type(**field_items) - for k, v in nonfield_items.items(): - setattr(result, k, v) - return result - def pretty(self, resolve: bool = False, sort_keys: bool = False) -> str: from omegaconf import OmegaConf diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index 9f418d57f..d00798da2 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -20,6 +20,7 @@ ValueKind, _get_value, _is_interpolation, + _is_missing_literal, _is_missing_value, _is_none, _valid_dict_key_annotation_type, @@ -35,7 +36,7 @@ type_str, valid_value_annotation_type, ) -from .base import Container, ContainerMetadata, DictKeyType, Node +from .base import Container, ContainerMetadata, DictKeyType, Node, SCMode from .basecontainer import BaseContainer from .errors import ( ConfigAttributeError, @@ -682,3 +683,49 @@ def _dict_conf_eq(d1: "DictConfig", d2: "DictConfig") -> bool: return False return True + + def _to_object(self) -> Any: + """Instantiate an instance of `self._metadata.object_type`. + This requires `self` to be a structured config. + Nested subconfigs are converted to_container with resolve=True.""" + from ._utils import get_structured_config_field_names + + object_type = self._metadata.object_type + assert is_structured_config(object_type) + object_type_field_names = set(get_structured_config_field_names(object_type)) + + field_items: Dict[str, Any] = {} + nonfield_items: Dict[str, Any] = {} + for k in self.keys(): + node = self._get_node(k) + assert isinstance(node, Node) + node = node._dereference_node(throw_on_resolution_failure=True) + assert node is not None + if isinstance(node, Container): + v = BaseContainer._to_content( + node, + resolve=True, + enum_to_str=False, + structured_config_mode=SCMode.INSTANTIATE, + # TODO: throw_on_missing=True, + ) + else: + v = node._value() + + if _is_missing_literal(v): + self._format_and_raise( + key=k, + value=None, + cause=MissingMandatoryValue( + "Structured config of type `$OBJECT_TYPE` has missing mandatory value: $KEY" + ), + ) + if k in object_type_field_names: + field_items[k] = v + else: + nonfield_items[k] = v + + result = object_type(**field_items) + for k, v in nonfield_items.items(): + setattr(result, k, v) + return result diff --git a/tests/structured_conf/data/attr_classes.py b/tests/structured_conf/data/attr_classes.py index c3b1df291..fc7cf3d5c 100644 --- a/tests/structured_conf/data/attr_classes.py +++ b/tests/structured_conf/data/attr_classes.py @@ -238,6 +238,14 @@ class Interpolation: z2: str = SI("${x}_${y}") +@attr.s(auto_attribs=True) +class RelativeInterpolation: + x: int = 100 + y: int = 200 + z1: int = II(".x") + z2: str = SI("${.x}_${.y}") + + @attr.s(auto_attribs=True) class BoolOptional: with_default: Optional[bool] = True diff --git a/tests/structured_conf/data/dataclasses.py b/tests/structured_conf/data/dataclasses.py index eecd52f87..ef5ccf69d 100644 --- a/tests/structured_conf/data/dataclasses.py +++ b/tests/structured_conf/data/dataclasses.py @@ -239,6 +239,14 @@ class Interpolation: z2: str = SI("${x}_${y}") +@dataclass +class RelativeInterpolation: + x: int = 100 + y: int = 200 + z1: int = II(".x") + z2: str = SI("${.x}_${.y}") + + @dataclass class BoolOptional: with_default: Optional[bool] = True diff --git a/tests/test_to_container.py b/tests/test_to_container.py index 2fabb6baf..fee18be83 100644 --- a/tests/test_to_container.py +++ b/tests/test_to_container.py @@ -328,19 +328,40 @@ def test_to_object_resolve_is_True_by_default(self, module: Any) -> None: def test_to_container_INSTANTIATE_resolve_False(self, module: Any) -> None: """Test the lower level `to_container` API with SCMode.INSTANTIATE and resolve=False""" - serialized = OmegaConf.structured(module.Interpolation) - interp = OmegaConf.to_container( - serialized, resolve=False, structured_config_mode=SCMode.INSTANTIATE + src = dict( + obj=module.RelativeInterpolation(), + interp_x="${obj.x}", + interp_x_y="${obj.x}_${obj.x}", ) - assert isinstance(interp, module.Interpolation) - assert type(interp) is module.Interpolation + nested = OmegaConf.create(src) + container = OmegaConf.to_container( + nested, resolve=False, structured_config_mode=SCMode.INSTANTIATE + ) + assert isinstance(container, dict) + assert container["interp_x"] == "${obj.x}" + assert container["interp_x_y"] == "${obj.x}_${obj.x}" + assert container["obj"].z1 == 100 + assert container["obj"].z2 == "100_200" - assert interp.z1 == "${x}" - assert interp.z2 == "${x}_${y}" + def test_to_container_INSTANTIATE_enum_to_str_True(self, module: Any) -> None: + """Test the lower level `to_container` API with SCMode.INSTANTIATE and resolve=False""" + src = dict( + color=Color.BLUE, + obj=module.EnumOptional(), + ) + nested = OmegaConf.create(src) + container = OmegaConf.to_container( + nested, enum_to_str=True, structured_config_mode=SCMode.INSTANTIATE + ) + assert isinstance(container, dict) + assert container["color"] == "BLUE" + assert container["obj"].not_optional is Color.BLUE def test_to_object_InterpolationResolutionError(self, module: Any) -> None: with raises(InterpolationResolutionError): - self.round_trip_to_object(module.NestedWithAny) + cfg = OmegaConf.structured(module.NestedWithAny) + cfg.var.mandatory_missing = 123 + OmegaConf.to_object(cfg) def test_nested_object_with_Any_ref_type(self, module: Any) -> None: cfg = OmegaConf.structured(module.NestedWithAny()) From 672b1809d1b082097cc46d8ff1263f2e665bbbec Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Fri, 2 Apr 2021 17:37:38 -0500 Subject: [PATCH 78/85] minor: revert whitespace addition --- omegaconf/basecontainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index d5372af64..3446b31a8 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -237,7 +237,6 @@ def convert(val: Node) -> Any: ) else: retdict[key] = convert(node) - return retdict elif isinstance(conf, ListConfig): retlist: List[Any] = [] From 8beb52ae255e892e5ef9e20bc19d8b7acea56707 Mon Sep 17 00:00:00 2001 From: Jasha10 <8935917+Jasha10@users.noreply.github.com> Date: Tue, 6 Apr 2021 16:09:43 -0500 Subject: [PATCH 79/85] Edit to news/472.feature Co-authored-by: Omry Yadan --- news/472.feature | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/news/472.feature b/news/472.feature index 6b05290d1..c51b70e8d 100644 --- a/news/472.feature +++ b/news/472.feature @@ -1 +1 @@ -Add the OmegaConf.to_object method, which converts Structured Config objects back to native dataclasses or attrs classes. +Add the OmegaConf.to_object method, which converts Structured Configs to native instances of the underlying `@dataclass` or `@attr.s` class. From 4787e8d34cc866e0a99e8744c5cedf5686c4a683 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 7 Apr 2021 13:45:56 -0500 Subject: [PATCH 80/85] don't mention enum_to_str --- docs/source/usage.rst | 3 --- 1 file changed, 3 deletions(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 063d90a11..772a2dd63 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -784,9 +784,6 @@ with the the keyword argument ``resolve=False``, so that interpolations are reso being used to instantiate dataclass/attr class instances. Interpolations within non-structured parent nodes will be resolved (or not) as usual, according to the ``resolve`` keyword arg. -Similarly, when ``structured_config_mode=SCMode.INSTANTIATE``, enum values nested within a -structured config node will not be converted to ``str``, even if ``OmegaConf.to_container`` -is called with ``enum_to_str=True``. .. doctest:: From c1d13f8a5419fe4c5aa5be714b742b996bf671a1 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 7 Apr 2021 13:56:26 -0500 Subject: [PATCH 81/85] formatting and title for structured_config_mode docs --- docs/source/usage.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 772a2dd63..54cdd7b6e 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -771,11 +771,16 @@ If resolve is set to True, interpolations will be resolved during conversion. >>> show(resolved) type: dict, value: {'foo': 'bar', 'foo2': 'bar'} + +Using ``structured_config_mode`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ You can customize the treatment of ``OmegaConf.to_container()`` for Structured Config nodes using the ``structured_config_mode`` option. By default, Structured Config nodes are converted to plain dict. + Using ``structured_config_mode=SCMode.DICT_CONFIG`` causes such nodes to remain as DictConfig, allowing attribute style access on the resulting node. + Using ``structured_config_mode=SCMode.INSTANTIATE``, Structured Config nodes are converted to instances of the backing dataclass or attrs class. Note that when ``structured_config_mode=SCMode.INSTANTIATE``, interpolations nested within From bc2f610322b3f55833cc3a115862a8e51af0efd3 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 7 Apr 2021 13:58:06 -0500 Subject: [PATCH 82/85] remove TODO comment --- omegaconf/dictconfig.py | 1 - 1 file changed, 1 deletion(-) diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index d00798da2..97be5d1c1 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -707,7 +707,6 @@ def _to_object(self) -> Any: resolve=True, enum_to_str=False, structured_config_mode=SCMode.INSTANTIATE, - # TODO: throw_on_missing=True, ) else: v = node._value() From f1a4270bdae1c8f6fc219e8afe5a33c098892823 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 7 Apr 2021 14:00:11 -0500 Subject: [PATCH 83/85] fix comment formatting --- omegaconf/dictconfig.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index 97be5d1c1..5d7369ce5 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -687,7 +687,8 @@ def _dict_conf_eq(d1: "DictConfig", d2: "DictConfig") -> bool: def _to_object(self) -> Any: """Instantiate an instance of `self._metadata.object_type`. This requires `self` to be a structured config. - Nested subconfigs are converted to_container with resolve=True.""" + Nested subconfigs are converted to_container with resolve=True. + """ from ._utils import get_structured_config_field_names object_type = self._metadata.object_type From b51d33f1135197cb1e204d2588b997f9e2b4f46c Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 7 Apr 2021 14:44:41 -0500 Subject: [PATCH 84/85] move `import get_structured_config_field_names` to top of file --- omegaconf/dictconfig.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index 5d7369ce5..96134832b 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -26,6 +26,7 @@ _valid_dict_key_annotation_type, format_and_raise, get_structured_config_data, + get_structured_config_field_names, get_type_of, get_value_kind, is_container_annotation, @@ -689,8 +690,6 @@ def _to_object(self) -> Any: This requires `self` to be a structured config. Nested subconfigs are converted to_container with resolve=True. """ - from ._utils import get_structured_config_field_names - object_type = self._metadata.object_type assert is_structured_config(object_type) object_type_field_names = set(get_structured_config_field_names(object_type)) From d12701e9add8ee7084b9ca8598870e441e974d6c Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 7 Apr 2021 17:46:28 -0500 Subject: [PATCH 85/85] one last formatting adjustment --- omegaconf/dictconfig.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index 96134832b..b868bcb7a 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -686,7 +686,8 @@ def _dict_conf_eq(d1: "DictConfig", d2: "DictConfig") -> bool: return True def _to_object(self) -> Any: - """Instantiate an instance of `self._metadata.object_type`. + """ + Instantiate an instance of `self._metadata.object_type`. This requires `self` to be a structured config. Nested subconfigs are converted to_container with resolve=True. """