diff --git a/news/431.bugfix b/news/431.bugfix new file mode 100644 index 000000000..19af09c97 --- /dev/null +++ b/news/431.bugfix @@ -0,0 +1 @@ +Fix bug where interpolations were unnecessarily resolved during merge diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 05d1d2f02..4f7eabbf3 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -281,6 +281,10 @@ def _map_merge(dest: "BaseContainer", src: "BaseContainer") -> None: assert isinstance(src, DictConfig) src_type = src._metadata.object_type + # if source DictConfig is an interpolation set the DictConfig one to be the same interpolation. + if src._is_interpolation(): + dest._set_value(src._value()) + return # if source DictConfig is missing set the DictConfig one to be missing too. if src._is_missing(): dest._set_value("???") @@ -298,7 +302,7 @@ def expand(node: Container) -> None: else: node._set_value(type_) - if dest._is_missing(): + if dest._is_interpolation() or dest._is_missing(): expand(dest) for key, src_value in src.items_ex(resolve=False): @@ -387,12 +391,11 @@ def _merge_with( if isinstance(self, DictConfig) and isinstance(other, DictConfig): BaseContainer._map_merge(self, other) elif isinstance(self, ListConfig) and isinstance(other, ListConfig): - if self._is_none() or self._is_missing() or self._is_interpolation(): - self.__dict__["_content"] = [] - else: - self.__dict__["_content"].clear() + self.__dict__["_content"] = [] - if other._is_missing(): + if other._is_interpolation(): + self._set_value(other._value()) + elif other._is_missing(): self._set_value("???") elif other._is_none(): self._set_value(None) @@ -577,7 +580,9 @@ def _is_none(self) -> bool: def _is_missing(self) -> bool: try: - self._dereference_node(throw_on_missing=True) + self._dereference_node( + throw_on_resolution_failure=False, throw_on_missing=True + ) return False except MissingMandatoryValue: ret = True diff --git a/tests/__init__.py b/tests/__init__.py index 72faaa35f..aff8302d2 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -154,3 +154,13 @@ class SubscriptedList: @dataclass class SubscriptedDict: dict: Dict[str, int] = field(default_factory=lambda: {"foo": 4}) + + +@dataclass +class InterpolationList: + list: List[float] = II("optimization.lr") + + +@dataclass +class InterpolationDict: + dict: Dict[str, int] = II("optimization.lr") diff --git a/tests/test_merge.py b/tests/test_merge.py index 5944f4db3..b4daa1196 100644 --- a/tests/test_merge.py +++ b/tests/test_merge.py @@ -23,6 +23,8 @@ ConfWithMissingDict, Group, IllegalType, + InterpolationDict, + InterpolationList, MissingDict, MissingList, Package, @@ -484,3 +486,51 @@ def test_merge_allow_objects() -> None: cfg._set_flag("allow_objects", True) ret = OmegaConf.merge(cfg, {"foo": iv}) assert ret == {"a": 10, "foo": iv} + + +@pytest.mark.parametrize( # type:ignore + "dst, other, expected, node", + [ + pytest.param( + OmegaConf.structured(InterpolationList), + OmegaConf.create({"list": [0.1]}), + {"list": [0.1]}, + "list", + id="merge_interpolation_list_with_list", + ), + pytest.param( + OmegaConf.structured(InterpolationDict), + OmegaConf.create({"dict": {"a": 4}}), + {"dict": {"a": 4}}, + "dict", + id="merge_interpolation_dict_with_dict", + ), + ], +) +def test_merge_with_src_as_interpolation( + dst: Any, other: Any, expected: Any, node: Any +) -> None: + res = OmegaConf.merge(dst, other) + assert res == expected + + +@pytest.mark.parametrize( # type:ignore + "dst, other, node", + [ + pytest.param( + OmegaConf.structured(InterpolationDict), + OmegaConf.structured(InterpolationDict), + "dict", + id="merge_interpolation_dict_with_interpolation_dict", + ), + pytest.param( + OmegaConf.structured(InterpolationList), + OmegaConf.structured(InterpolationList), + "list", + id="merge_interpolation_list_with_interpolation_list", + ), + ], +) +def test_merge_with_other_as_interpolation(dst: Any, other: Any, node: Any) -> None: + res = OmegaConf.merge(dst, other) + assert OmegaConf.is_interpolation(res, node)