diff --git a/news/431.bugfix b/news/431.bugfix new file mode 100644 index 000000000..19af09c97 --- /dev/null +++ b/news/431.bugfix @@ -0,0 +1 @@ +Fix bug where interpolations were unnecessarily resolved during merge diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index d4ea0e30e..e31cf32ef 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -243,6 +243,10 @@ def _map_merge(dest: "BaseContainer", src: "BaseContainer") -> None: assert isinstance(src, DictConfig) src_type = src._metadata.object_type + # if source DictConfig is an interpolation set the DictConfig one to be the same interpolation. + if src._is_interpolation(): + dest._set_value(src._value()) + return # if source DictConfig is missing set the DictConfig one to be missing too. if src._is_missing(): dest._set_value("???") @@ -260,7 +264,7 @@ def expand(node: Container) -> None: else: node._set_value(type_) - if dest._is_missing(): + if dest._is_interpolation() or dest._is_missing(): expand(dest) for key, src_value in src.items_ex(resolve=False): @@ -342,12 +346,11 @@ def _merge_with( if isinstance(self, DictConfig) and isinstance(other, DictConfig): BaseContainer._map_merge(self, other) elif isinstance(self, ListConfig) and isinstance(other, ListConfig): - if self._is_none() or self._is_missing() or self._is_interpolation(): - self.__dict__["_content"] = [] - else: - self.__dict__["_content"].clear() + self.__dict__["_content"] = [] - if other._is_missing(): + if other._is_interpolation(): + self._set_value(other._value()) + elif other._is_missing(): self._set_value("???") elif other._is_none(): self._set_value(None) @@ -526,7 +529,9 @@ def _is_none(self) -> bool: def _is_missing(self) -> bool: try: - self._dereference_node(throw_on_missing=True) + self._dereference_node( + throw_on_resolution_failure=False, throw_on_missing=True + ) return False except MissingMandatoryValue: ret = True diff --git a/tests/__init__.py b/tests/__init__.py index 1e21dca55..d8b600faa 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -144,3 +144,13 @@ class Module: @dataclass class Package: modules: List[Module] = MISSING + + +@dataclass +class InterpolationList: + list: List[float] = II("optimization.lr") + + +@dataclass +class InterpolationDict: + dict: Dict[str, int] = II("optimization.lr") diff --git a/tests/test_merge.py b/tests/test_merge.py index 9e2a867de..617dbf491 100644 --- a/tests/test_merge.py +++ b/tests/test_merge.py @@ -22,6 +22,8 @@ ConcretePlugin, ConfWithMissingDict, Group, + InterpolationDict, + InterpolationList, MissingDict, MissingList, Package, @@ -472,3 +474,51 @@ def test_merge_with_dotlist_errors(dotlist: List[str]) -> None: c = OmegaConf.create() with pytest.raises(ValueError): c.merge_with_dotlist(dotlist) + + +@pytest.mark.parametrize( # type:ignore + "dst, other, expected, node", + [ + pytest.param( + OmegaConf.structured(InterpolationList), + OmegaConf.create({"list": [0.1]}), + {"list": [0.1]}, + "list", + id="merge_interpolation_list_with_list", + ), + pytest.param( + OmegaConf.structured(InterpolationDict), + OmegaConf.create({"dict": {"a": 4}}), + {"dict": {"a": 4}}, + "dict", + id="merge_interpolation_dict_with_dict", + ), + ], +) +def test_merge_with_src_as_interpolation( + dst: Any, other: Any, expected: Any, node: Any +) -> None: + res = OmegaConf.merge(dst, other) + assert res == expected + + +@pytest.mark.parametrize( # type:ignore + "dst, other, node", + [ + pytest.param( + OmegaConf.structured(InterpolationDict), + OmegaConf.structured(InterpolationDict), + "dict", + id="merge_interpolation_dict_with_interpolation_dict", + ), + pytest.param( + OmegaConf.structured(InterpolationList), + OmegaConf.structured(InterpolationList), + "list", + id="merge_interpolation_list_with_interpolation_list", + ), + ], +) +def test_merge_with_other_as_interpolation(dst: Any, other: Any, node: Any) -> None: + res = OmegaConf.merge(dst, other) + assert OmegaConf.is_interpolation(res, node)