Skip to content

Commit

Permalink
Refactoring: _maybe_dereference_node (#668)
Browse files Browse the repository at this point in the history
* create _dereference_node_impl

* introduce _maybe_dereference_node

* change return type for _dereference_node

* _dereference_node_impl: change fn signature (remove default argument)
  • Loading branch information
Jasha10 authored Apr 9, 2021
1 parent a22b70c commit 9d12412
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 23 deletions.
2 changes: 0 additions & 2 deletions omegaconf/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ def _resolve_container_value(cfg: Container, key: Any) -> None:
except InterpolationToMissingValueError:
node._set_value(MISSING)
else:
assert resolved is not None
if isinstance(resolved, Container):
_resolve(resolved)
if isinstance(resolved, Container) and isinstance(node, ValueNode):
Expand All @@ -32,7 +31,6 @@ def _resolve(cfg: Node) -> Node:
except InterpolationToMissingValueError:
cfg._set_value(MISSING)
else:
assert resolved is not None
cfg._set_value(resolved._value())

if isinstance(cfg, DictConfig):
Expand Down
2 changes: 1 addition & 1 deletion omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def _is_none(
return value is None

if resolve:
value = value._dereference_node(
value = value._maybe_dereference_node(
throw_on_resolution_failure=throw_on_resolution_failure
)
if not throw_on_resolution_failure and value is None:
Expand Down
18 changes: 15 additions & 3 deletions omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,21 @@ def _format_and_raise(
def _get_full_key(self, key: Optional[Union[DictKeyType, int]]) -> str:
...

def _dereference_node(
def _dereference_node(self) -> "Node":
node = self._dereference_node_impl(throw_on_resolution_failure=True)
assert node is not None
return node

def _maybe_dereference_node(
self,
throw_on_resolution_failure: bool = True,
throw_on_resolution_failure: bool = False,
) -> Optional["Node"]:
return self._dereference_node_impl(
throw_on_resolution_failure=throw_on_resolution_failure
)

def _dereference_node_impl(
self, throw_on_resolution_failure: bool
) -> Optional["Node"]:
if not self._is_interpolation():
return self
Expand Down Expand Up @@ -370,7 +382,7 @@ def _select_impl(
throw_on_type_error=throw_on_resolution_failure,
)
if isinstance(ret, Node):
ret = ret._dereference_node(
ret = ret._maybe_dereference_node(
throw_on_resolution_failure=throw_on_resolution_failure,
)

Expand Down
14 changes: 5 additions & 9 deletions omegaconf/basecontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,8 @@ def convert(val: Node) -> Any:
node = conf._get_node(key)
assert isinstance(node, Node)
if resolve:
node = node._dereference_node(throw_on_resolution_failure=True)
node = node._dereference_node()

assert node is not None
if enum_to_str and isinstance(key, Enum):
key = f"{key.name}"
if isinstance(node, Container):
Expand All @@ -220,8 +219,7 @@ def convert(val: Node) -> Any:
node = conf._get_node(index)
assert isinstance(node, Node)
if resolve:
node = node._dereference_node(throw_on_resolution_failure=True)
assert node is not None
node = node._dereference_node()
if isinstance(node, Container):
item = BaseContainer._to_content(
node,
Expand Down Expand Up @@ -310,9 +308,7 @@ def expand(node: Container) -> None:
expand(dest_node)

if dest_node is not None and dest_node._is_interpolation():
target_node = dest_node._dereference_node(
throw_on_resolution_failure=False
)
target_node = dest_node._maybe_dereference_node()
if isinstance(target_node, Container):
dest[key] = target_node
dest_node = dest._get_node(key)
Expand Down Expand Up @@ -591,9 +587,9 @@ def _item_eq(
dv2: Optional[Node] = v2

if v1_inter:
dv1 = v1._dereference_node(throw_on_resolution_failure=False)
dv1 = v1._maybe_dereference_node()
if v2_inter:
dv2 = v2._dereference_node(throw_on_resolution_failure=False)
dv2 = v2._maybe_dereference_node()

if v1_inter and v2_inter:
if dv1 is None or dv2 is None:
Expand Down
3 changes: 1 addition & 2 deletions omegaconf/dictconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,8 +700,7 @@ def _to_object(self) -> Any:
for k in self.keys():
node = self._get_node(k)
assert isinstance(node, Node)
node = node._dereference_node(throw_on_resolution_failure=True)
assert node is not None
node = node._dereference_node()
if isinstance(node, Container):
v = BaseContainer._to_content(
node,
Expand Down
2 changes: 1 addition & 1 deletion pydevd_plugins/extensions/pydevd_plugin_omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def get_dictionary(self, obj: Any) -> Dict[str, Any]:

def _get_dictionary(self, obj: Any) -> Dict[str, Any]:
if isinstance(obj, self.Node):
obj = obj._dereference_node(throw_on_resolution_failure=False)
obj = obj._maybe_dereference_node()
if obj is None or obj._is_none() or obj._is_missing():
return {}

Expand Down
9 changes: 5 additions & 4 deletions tests/interpolation/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def test_resolve_interpolation_without_parent() -> None:

def test_resolve_interpolation_without_parent_no_throw() -> None:
cfg = DictConfig(content="${foo}")
assert cfg._dereference_node(throw_on_resolution_failure=False) is None
assert cfg._maybe_dereference_node() is None


def test_optional_after_interpolation() -> None:
Expand All @@ -227,7 +227,8 @@ def test_invalid_intermediate_result_when_not_throwing(
The main goal of this test is to make sure that the resolution of an interpolation
is stopped immediately when a missing / resolution failure occurs, even if
`throw_on_resolution_failure` is set to False.
`_maybe_dereference_node(throw_on_resolution_failure=False)` is used
instead of `_dereference_node`.
When this happens while dereferencing a node, the result should be `None`.
"""

Expand All @@ -243,7 +244,7 @@ def fail_if_called(x: Any) -> None:
)
x_node = cfg._get_node("x")
assert isinstance(x_node, Node)
assert x_node._dereference_node(throw_on_resolution_failure=False) is None
assert x_node._maybe_dereference_node(throw_on_resolution_failure=False) is None


def test_none_value_in_quoted_string(restore_resolvers: Any) -> None:
Expand Down Expand Up @@ -442,4 +443,4 @@ def test_interpolation_readonly_node() -> None:
def test_type_validation_error_no_throw() -> None:
cfg = OmegaConf.structured(User(name="Bond", age=SI("${name}")))
bad_node = cfg._get_node("age")
assert bad_node._dereference_node(throw_on_resolution_failure=False) is None
assert bad_node._maybe_dereference_node() is None
2 changes: 1 addition & 1 deletion tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,6 @@ def test_dereference_interpolation_to_missing() -> None:
cfg = OmegaConf.create({"x": "${y}", "y": "???"})
x_node = cfg._get_node("x")
assert isinstance(x_node, Node)
assert x_node._dereference_node(throw_on_resolution_failure=False) is None
assert x_node._maybe_dereference_node() is None
with raises(InterpolationToMissingValueError):
cfg.x

0 comments on commit 9d12412

Please sign in to comment.