From b258b77a7abca1b7db400ed482c43a6bcfeff34a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pere=20D=C3=ADaz?= Date: Fri, 6 Nov 2020 17:49:29 +0100 Subject: [PATCH 1/6] fix and tests --- omegaconf/basecontainer.py | 4 ++-- tests/__init__.py | 10 ++++++++++ tests/test_merge.py | 22 ++++++++++++++++++++++ 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 05d1d2f02..e34db8383 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -298,7 +298,7 @@ def expand(node: Container) -> None: else: node._set_value(type_) - if dest._is_missing(): + if dest._is_interpolation() or dest._is_missing(): expand(dest) for key, src_value in src.items_ex(resolve=False): @@ -387,7 +387,7 @@ def _merge_with( if isinstance(self, DictConfig) and isinstance(other, DictConfig): BaseContainer._map_merge(self, other) elif isinstance(self, ListConfig) and isinstance(other, ListConfig): - if self._is_none() or self._is_missing() or self._is_interpolation(): + if self._is_none() or self._is_interpolation() or self._is_missing(): self.__dict__["_content"] = [] else: self.__dict__["_content"].clear() diff --git a/tests/__init__.py b/tests/__init__.py index 72faaa35f..b630fbe8c 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -154,3 +154,13 @@ class SubscriptedList: @dataclass class SubscriptedDict: dict: Dict[str, int] = field(default_factory=lambda: {"foo": 4}) + + +@dataclass +class InterpolationList: + list: List[float] = "${optimization.lr}" # type: ignore + + +@dataclass +class InterpolationDict: + dict: Dict[str, int] = "${optimization.lr}" # type: ignore diff --git a/tests/test_merge.py b/tests/test_merge.py index 5944f4db3..ee249ca9e 100644 --- a/tests/test_merge.py +++ b/tests/test_merge.py @@ -23,6 +23,8 @@ ConfWithMissingDict, Group, IllegalType, + InterpolationDict, + InterpolationList, MissingDict, MissingList, Package, @@ -484,3 +486,23 @@ def test_merge_allow_objects() -> None: cfg._set_flag("allow_objects", True) ret = OmegaConf.merge(cfg, {"foo": iv}) assert ret == {"a": 10, "foo": iv} + + +@pytest.mark.parametrize( # type:ignore + "dst, other, expected", + [ + ( + OmegaConf.structured(InterpolationList), + OmegaConf.create({"list": [0.1]}), + {"list": [0.1]}, + ), + ( + OmegaConf.structured(InterpolationDict), + OmegaConf.create({"dict": {"a": 4}}), + {"dict": {"a": 4}}, + ), + ], +) +def test_merge_with_interpolation(dst: Any, other: Any, expected: Any) -> None: + res = OmegaConf.merge(dst, other) + assert res == expected From 9917717b6d22807cdd01c2508d29d0165daf4df6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pere=20D=C3=ADaz?= Date: Fri, 6 Nov 2020 22:18:43 +0100 Subject: [PATCH 2/6] test cases, interpolation as other fix --- omegaconf/basecontainer.py | 21 +++++++++++++-------- omegaconf/dictconfig.py | 2 +- omegaconf/omegaconf.py | 2 ++ tests/__init__.py | 4 ++-- tests/test_merge.py | 29 ++++++++++++++++++++++++++--- 5 files changed, 44 insertions(+), 14 deletions(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index e34db8383..2c07d5551 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -281,8 +281,11 @@ def _map_merge(dest: "BaseContainer", src: "BaseContainer") -> None: assert isinstance(src, DictConfig) src_type = src._metadata.object_type + if src._is_interpolation(): + dest._set_value(src._value()) + return # if source DictConfig is missing set the DictConfig one to be missing too. - if src._is_missing(): + if src._is_missing(throw_on_resolution_failure=False): dest._set_value("???") return dest._validate_merge(key=None, value=src) @@ -387,12 +390,11 @@ def _merge_with( if isinstance(self, DictConfig) and isinstance(other, DictConfig): BaseContainer._map_merge(self, other) elif isinstance(self, ListConfig) and isinstance(other, ListConfig): - if self._is_none() or self._is_interpolation() or self._is_missing(): - self.__dict__["_content"] = [] - else: - self.__dict__["_content"].clear() + self.__dict__["_content"] = [] - if other._is_missing(): + if other._is_interpolation(): + self._set_value(other) + elif other._is_missing(): self._set_value("???") elif other._is_none(): self._set_value(None) @@ -575,9 +577,12 @@ def _item_eq( def _is_none(self) -> bool: return self.__dict__["_content"] is None - def _is_missing(self) -> bool: + def _is_missing(self, throw_on_resolution_failure: bool = True) -> bool: try: - self._dereference_node(throw_on_missing=True) + self._dereference_node( + throw_on_resolution_failure=throw_on_resolution_failure, + throw_on_missing=True, + ) return False except MissingMandatoryValue: ret = True diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index 7a21a518c..4a3d64dd6 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -445,7 +445,7 @@ def pop(self, key: Union[str, Enum], default: Any = DEFAULT_VALUE_MARKER) -> Any self._format_and_raise(key=key, value=None, cause=e) def keys(self) -> Any: - if self._is_missing() or self._is_interpolation() or self._is_none(): + if self._is_interpolation() or self._is_missing() or self._is_none(): return list() return self.__dict__["_content"].keys() diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index 690d08475..fe9b3b13a 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -563,6 +563,8 @@ def _get_obj_type(c: Any) -> Optional[Type[Any]]: elif isinstance(c, DictConfig): if c._is_none(): return None + elif c._is_interpolation(): + return None elif c._is_missing(): return None else: diff --git a/tests/__init__.py b/tests/__init__.py index b630fbe8c..aff8302d2 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -158,9 +158,9 @@ class SubscriptedDict: @dataclass class InterpolationList: - list: List[float] = "${optimization.lr}" # type: ignore + list: List[float] = II("optimization.lr") @dataclass class InterpolationDict: - dict: Dict[str, int] = "${optimization.lr}" # type: ignore + dict: Dict[str, int] = II("optimization.lr") diff --git a/tests/test_merge.py b/tests/test_merge.py index ee249ca9e..3d960795a 100644 --- a/tests/test_merge.py +++ b/tests/test_merge.py @@ -489,20 +489,43 @@ def test_merge_allow_objects() -> None: @pytest.mark.parametrize( # type:ignore - "dst, other, expected", + "dst, other, expected, node, check_missing", [ ( OmegaConf.structured(InterpolationList), OmegaConf.create({"list": [0.1]}), {"list": [0.1]}, + "list", + False, ), ( OmegaConf.structured(InterpolationDict), OmegaConf.create({"dict": {"a": 4}}), {"dict": {"a": 4}}, + "dict", + False, + ), + ( + OmegaConf.structured(InterpolationDict), + OmegaConf.structured(InterpolationDict), + None, + "dict", + True, + ), + ( + OmegaConf.structured(InterpolationList), + OmegaConf.structured(InterpolationList), + None, + "list", + True, ), ], ) -def test_merge_with_interpolation(dst: Any, other: Any, expected: Any) -> None: +def test_merge_with_interpolation( + dst: Any, other: Any, expected: Any, node: Any, check_missing: bool +) -> None: res = OmegaConf.merge(dst, other) - assert res == expected + if check_missing: + OmegaConf.is_missing(res, node) + else: + assert res == expected From 026b3857bc8d16bf92977fd6c84c7de992ace78c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pere=20D=C3=ADaz?= Date: Fri, 6 Nov 2020 22:20:20 +0100 Subject: [PATCH 3/6] news --- news/431.bugfix | 1 + 1 file changed, 1 insertion(+) create mode 100644 news/431.bugfix diff --git a/news/431.bugfix b/news/431.bugfix new file mode 100644 index 000000000..19af09c97 --- /dev/null +++ b/news/431.bugfix @@ -0,0 +1 @@ +Fix bug where interpolations were unnecessarily resolved during merge From 29b33044a6a7a6711785d9e3fa01a725a6ec81a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pere=20D=C3=ADaz?= Date: Sat, 7 Nov 2020 16:28:58 +0100 Subject: [PATCH 4/6] name tests and no trow on resolution failure as default --- omegaconf/basecontainer.py | 7 +++---- tests/test_merge.py | 12 ++++++++---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 2c07d5551..b6a2e9c5f 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -285,7 +285,7 @@ def _map_merge(dest: "BaseContainer", src: "BaseContainer") -> None: dest._set_value(src._value()) return # if source DictConfig is missing set the DictConfig one to be missing too. - if src._is_missing(throw_on_resolution_failure=False): + if src._is_missing(): dest._set_value("???") return dest._validate_merge(key=None, value=src) @@ -577,11 +577,10 @@ def _item_eq( def _is_none(self) -> bool: return self.__dict__["_content"] is None - def _is_missing(self, throw_on_resolution_failure: bool = True) -> bool: + def _is_missing(self) -> bool: try: self._dereference_node( - throw_on_resolution_failure=throw_on_resolution_failure, - throw_on_missing=True, + throw_on_resolution_failure=False, throw_on_missing=True ) return False except MissingMandatoryValue: diff --git a/tests/test_merge.py b/tests/test_merge.py index 3d960795a..53c43d9b8 100644 --- a/tests/test_merge.py +++ b/tests/test_merge.py @@ -491,33 +491,37 @@ def test_merge_allow_objects() -> None: @pytest.mark.parametrize( # type:ignore "dst, other, expected, node, check_missing", [ - ( + pytest.param( OmegaConf.structured(InterpolationList), OmegaConf.create({"list": [0.1]}), {"list": [0.1]}, "list", False, + id="merge_interpolation_list_with_list", ), - ( + pytest.param( OmegaConf.structured(InterpolationDict), OmegaConf.create({"dict": {"a": 4}}), {"dict": {"a": 4}}, "dict", False, + id="merge_interpolation_dict_with_dict", ), - ( + pytest.param( OmegaConf.structured(InterpolationDict), OmegaConf.structured(InterpolationDict), None, "dict", True, + id="merge_interpolation_dict_with_interpolation_dict", ), - ( + pytest.param( OmegaConf.structured(InterpolationList), OmegaConf.structured(InterpolationList), None, "list", True, + id="merge_interpolation_list_with_interpolation_list", ), ], ) From e047a269ad7f9a51f2cbf28f6d21c4c22d940fae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pere=20D=C3=ADaz?= Date: Sun, 8 Nov 2020 13:08:45 +0100 Subject: [PATCH 5/6] split test and revert some changes --- omegaconf/basecontainer.py | 2 +- omegaconf/dictconfig.py | 2 +- omegaconf/omegaconf.py | 2 -- tests/test_merge.py | 29 +++++++++++++++-------------- 4 files changed, 17 insertions(+), 18 deletions(-) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index b6a2e9c5f..d1b7b4c01 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -393,7 +393,7 @@ def _merge_with( self.__dict__["_content"] = [] if other._is_interpolation(): - self._set_value(other) + self._set_value(other._value()) elif other._is_missing(): self._set_value("???") elif other._is_none(): diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index 4a3d64dd6..7a21a518c 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -445,7 +445,7 @@ def pop(self, key: Union[str, Enum], default: Any = DEFAULT_VALUE_MARKER) -> Any self._format_and_raise(key=key, value=None, cause=e) def keys(self) -> Any: - if self._is_interpolation() or self._is_missing() or self._is_none(): + if self._is_missing() or self._is_interpolation() or self._is_none(): return list() return self.__dict__["_content"].keys() diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index fe9b3b13a..690d08475 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -563,8 +563,6 @@ def _get_obj_type(c: Any) -> Optional[Type[Any]]: elif isinstance(c, DictConfig): if c._is_none(): return None - elif c._is_interpolation(): - return None elif c._is_missing(): return None else: diff --git a/tests/test_merge.py b/tests/test_merge.py index 53c43d9b8..b4daa1196 100644 --- a/tests/test_merge.py +++ b/tests/test_merge.py @@ -489,14 +489,13 @@ def test_merge_allow_objects() -> None: @pytest.mark.parametrize( # type:ignore - "dst, other, expected, node, check_missing", + "dst, other, expected, node", [ pytest.param( OmegaConf.structured(InterpolationList), OmegaConf.create({"list": [0.1]}), {"list": [0.1]}, "list", - False, id="merge_interpolation_list_with_list", ), pytest.param( @@ -504,32 +503,34 @@ def test_merge_allow_objects() -> None: OmegaConf.create({"dict": {"a": 4}}), {"dict": {"a": 4}}, "dict", - False, id="merge_interpolation_dict_with_dict", ), + ], +) +def test_merge_with_src_as_interpolation( + dst: Any, other: Any, expected: Any, node: Any +) -> None: + res = OmegaConf.merge(dst, other) + assert res == expected + + +@pytest.mark.parametrize( # type:ignore + "dst, other, node", + [ pytest.param( OmegaConf.structured(InterpolationDict), OmegaConf.structured(InterpolationDict), - None, "dict", - True, id="merge_interpolation_dict_with_interpolation_dict", ), pytest.param( OmegaConf.structured(InterpolationList), OmegaConf.structured(InterpolationList), - None, "list", - True, id="merge_interpolation_list_with_interpolation_list", ), ], ) -def test_merge_with_interpolation( - dst: Any, other: Any, expected: Any, node: Any, check_missing: bool -) -> None: +def test_merge_with_other_as_interpolation(dst: Any, other: Any, node: Any) -> None: res = OmegaConf.merge(dst, other) - if check_missing: - OmegaConf.is_missing(res, node) - else: - assert res == expected + assert OmegaConf.is_interpolation(res, node) From a7b476540fb73bb1230e5a844730b6df4612b637 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pere=20D=C3=ADaz?= Date: Mon, 9 Nov 2020 11:32:18 +0100 Subject: [PATCH 6/6] add interpolation comment --- omegaconf/basecontainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index d1b7b4c01..4f7eabbf3 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -281,6 +281,7 @@ def _map_merge(dest: "BaseContainer", src: "BaseContainer") -> None: assert isinstance(src, DictConfig) src_type = src._metadata.object_type + # if source DictConfig is an interpolation set the DictConfig one to be the same interpolation. if src._is_interpolation(): dest._set_value(src._value()) return