Skip to content

Commit

Permalink
Fix bug where interpolations were unnecessarily resolved during merge. (
Browse files Browse the repository at this point in the history
  • Loading branch information
pereman2 authored Nov 11, 2020
1 parent 7e1ff81 commit b0d59bf
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 7 deletions.
1 change: 1 addition & 0 deletions news/431.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug where interpolations were unnecessarily resolved during merge
19 changes: 12 additions & 7 deletions omegaconf/basecontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,10 @@ def _map_merge(dest: "BaseContainer", src: "BaseContainer") -> None:
assert isinstance(src, DictConfig)
src_type = src._metadata.object_type

# if source DictConfig is an interpolation set the DictConfig one to be the same interpolation.
if src._is_interpolation():
dest._set_value(src._value())
return
# if source DictConfig is missing set the DictConfig one to be missing too.
if src._is_missing():
dest._set_value("???")
Expand All @@ -260,7 +264,7 @@ def expand(node: Container) -> None:
else:
node._set_value(type_)

if dest._is_missing():
if dest._is_interpolation() or dest._is_missing():
expand(dest)

for key, src_value in src.items_ex(resolve=False):
Expand Down Expand Up @@ -342,12 +346,11 @@ def _merge_with(
if isinstance(self, DictConfig) and isinstance(other, DictConfig):
BaseContainer._map_merge(self, other)
elif isinstance(self, ListConfig) and isinstance(other, ListConfig):
if self._is_none() or self._is_missing() or self._is_interpolation():
self.__dict__["_content"] = []
else:
self.__dict__["_content"].clear()
self.__dict__["_content"] = []

if other._is_missing():
if other._is_interpolation():
self._set_value(other._value())
elif other._is_missing():
self._set_value("???")
elif other._is_none():
self._set_value(None)
Expand Down Expand Up @@ -526,7 +529,9 @@ def _is_none(self) -> bool:

def _is_missing(self) -> bool:
try:
self._dereference_node(throw_on_missing=True)
self._dereference_node(
throw_on_resolution_failure=False, throw_on_missing=True
)
return False
except MissingMandatoryValue:
ret = True
Expand Down
10 changes: 10 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,13 @@ class Module:
@dataclass
class Package:
modules: List[Module] = MISSING


@dataclass
class InterpolationList:
list: List[float] = II("optimization.lr")


@dataclass
class InterpolationDict:
dict: Dict[str, int] = II("optimization.lr")
50 changes: 50 additions & 0 deletions tests/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
ConcretePlugin,
ConfWithMissingDict,
Group,
InterpolationDict,
InterpolationList,
MissingDict,
MissingList,
Package,
Expand Down Expand Up @@ -472,3 +474,51 @@ def test_merge_with_dotlist_errors(dotlist: List[str]) -> None:
c = OmegaConf.create()
with pytest.raises(ValueError):
c.merge_with_dotlist(dotlist)


@pytest.mark.parametrize( # type:ignore
"dst, other, expected, node",
[
pytest.param(
OmegaConf.structured(InterpolationList),
OmegaConf.create({"list": [0.1]}),
{"list": [0.1]},
"list",
id="merge_interpolation_list_with_list",
),
pytest.param(
OmegaConf.structured(InterpolationDict),
OmegaConf.create({"dict": {"a": 4}}),
{"dict": {"a": 4}},
"dict",
id="merge_interpolation_dict_with_dict",
),
],
)
def test_merge_with_src_as_interpolation(
dst: Any, other: Any, expected: Any, node: Any
) -> None:
res = OmegaConf.merge(dst, other)
assert res == expected


@pytest.mark.parametrize( # type:ignore
"dst, other, node",
[
pytest.param(
OmegaConf.structured(InterpolationDict),
OmegaConf.structured(InterpolationDict),
"dict",
id="merge_interpolation_dict_with_interpolation_dict",
),
pytest.param(
OmegaConf.structured(InterpolationList),
OmegaConf.structured(InterpolationList),
"list",
id="merge_interpolation_list_with_interpolation_list",
),
],
)
def test_merge_with_other_as_interpolation(dst: Any, other: Any, node: Any) -> None:
res = OmegaConf.merge(dst, other)
assert OmegaConf.is_interpolation(res, node)

0 comments on commit b0d59bf

Please sign in to comment.