-
Notifications
You must be signed in to change notification settings - Fork 116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix bug where interpolations were unnecessarily resolved during merge. #432
Changes from 3 commits
b258b77
9917717
026b385
29b3304
e047a26
a7b4765
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Fix bug where interpolations were unnecessarily resolved during merge |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -281,8 +281,11 @@ def _map_merge(dest: "BaseContainer", src: "BaseContainer") -> None: | |
assert isinstance(src, DictConfig) | ||
src_type = src._metadata.object_type | ||
|
||
if src._is_interpolation(): | ||
dest._set_value(src._value()) | ||
return | ||
# if source DictConfig is missing set the DictConfig one to be missing too. | ||
if src._is_missing(): | ||
if src._is_missing(throw_on_resolution_failure=False): | ||
dest._set_value("???") | ||
return | ||
dest._validate_merge(key=None, value=src) | ||
|
@@ -298,7 +301,7 @@ def expand(node: Container) -> None: | |
else: | ||
node._set_value(type_) | ||
|
||
if dest._is_missing(): | ||
if dest._is_interpolation() or dest._is_missing(): | ||
omry marked this conversation as resolved.
Show resolved
Hide resolved
|
||
expand(dest) | ||
|
||
for key, src_value in src.items_ex(resolve=False): | ||
|
@@ -387,12 +390,11 @@ def _merge_with( | |
if isinstance(self, DictConfig) and isinstance(other, DictConfig): | ||
BaseContainer._map_merge(self, other) | ||
elif isinstance(self, ListConfig) and isinstance(other, ListConfig): | ||
if self._is_none() or self._is_missing() or self._is_interpolation(): | ||
self.__dict__["_content"] = [] | ||
else: | ||
self.__dict__["_content"].clear() | ||
self.__dict__["_content"] = [] | ||
|
||
if other._is_missing(): | ||
if other._is_interpolation(): | ||
self._set_value(other) | ||
elif other._is_missing(): | ||
self._set_value("???") | ||
elif other._is_none(): | ||
self._set_value(None) | ||
|
@@ -575,9 +577,12 @@ def _item_eq( | |
def _is_none(self) -> bool: | ||
return self.__dict__["_content"] is None | ||
|
||
def _is_missing(self) -> bool: | ||
def _is_missing(self, throw_on_resolution_failure: bool = True) -> bool: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What if you change this function to not throw on resolution failure (without a parameter), does this cause any issues? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It doesn't cause any issue, It could be changed to that. |
||
try: | ||
self._dereference_node(throw_on_missing=True) | ||
self._dereference_node( | ||
throw_on_resolution_failure=throw_on_resolution_failure, | ||
throw_on_missing=True, | ||
) | ||
return False | ||
except MissingMandatoryValue: | ||
ret = True | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -445,7 +445,7 @@ def pop(self, key: Union[str, Enum], default: Any = DEFAULT_VALUE_MARKER) -> Any | |
self._format_and_raise(key=key, value=None, cause=e) | ||
|
||
def keys(self) -> Any: | ||
if self._is_missing() or self._is_interpolation() or self._is_none(): | ||
if self._is_interpolation() or self._is_missing() or self._is_none(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not that _is_missing dies not throw, you can undo this change. |
||
return list() | ||
return self.__dict__["_content"].keys() | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -563,6 +563,8 @@ def _get_obj_type(c: Any) -> Optional[Type[Any]]: | |
elif isinstance(c, DictConfig): | ||
if c._is_none(): | ||
return None | ||
elif c._is_interpolation(): | ||
return None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please explain. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now that is_missing doesn't throw on resolution failure it doesn't matter this line. I'll remove it. |
||
elif c._is_missing(): | ||
return None | ||
else: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,8 @@ | |
ConfWithMissingDict, | ||
Group, | ||
IllegalType, | ||
InterpolationDict, | ||
InterpolationList, | ||
MissingDict, | ||
MissingList, | ||
Package, | ||
|
@@ -484,3 +486,46 @@ def test_merge_allow_objects() -> None: | |
cfg._set_flag("allow_objects", True) | ||
ret = OmegaConf.merge(cfg, {"foo": iv}) | ||
assert ret == {"a": 10, "foo": iv} | ||
|
||
|
||
@pytest.mark.parametrize( # type:ignore | ||
"dst, other, expected, node, check_missing", | ||
[ | ||
( | ||
OmegaConf.structured(InterpolationList), | ||
OmegaConf.create({"list": [0.1]}), | ||
{"list": [0.1]}, | ||
"list", | ||
False, | ||
), | ||
( | ||
OmegaConf.structured(InterpolationDict), | ||
OmegaConf.create({"dict": {"a": 4}}), | ||
{"dict": {"a": 4}}, | ||
"dict", | ||
False, | ||
), | ||
( | ||
OmegaConf.structured(InterpolationDict), | ||
OmegaConf.structured(InterpolationDict), | ||
None, | ||
"dict", | ||
True, | ||
), | ||
( | ||
OmegaConf.structured(InterpolationList), | ||
OmegaConf.structured(InterpolationList), | ||
None, | ||
"list", | ||
True, | ||
), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you name tests? |
||
], | ||
) | ||
def test_merge_with_interpolation( | ||
dst: Any, other: Any, expected: Any, node: Any, check_missing: bool | ||
) -> None: | ||
res = OmegaConf.merge(dst, other) | ||
if check_missing: | ||
OmegaConf.is_missing(res, node) | ||
else: | ||
assert res == expected | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do one thing in each test. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm changing it to is_interpolation |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please explain this one.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Merging an interpolation to a container is basically keeping the interpolation in dest therefore, dest's value should be src's value.
If I don't do this check here it will fail in
for key, src_value in src.items_ex(resolve=False):
(line 307).To sum up, it's the same as line 289 where
if src._is_missing(): dest._set_value("???")
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok. See that there is a comment explaining it in 288.
Add a similar comment.