Skip to content

Commit

Permalink
refactor select_value and select_node out of OmegaConf.select
Browse files Browse the repository at this point in the history
  • Loading branch information
omry committed Apr 15, 2021
1 parent 5699dec commit 0f6ad91
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 44 deletions.
69 changes: 68 additions & 1 deletion omegaconf/_impl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Any

from omegaconf import MISSING, Container, DictConfig, ListConfig, Node, ValueNode
from omegaconf.errors import InterpolationToMissingValueError
from omegaconf.errors import ConfigKeyError, InterpolationToMissingValueError

from ._utils import _DEFAULT_MARKER_, _get_value


def _resolve_container_value(cfg: Container, key: Any) -> None:
Expand Down Expand Up @@ -42,3 +44,68 @@ def _resolve(cfg: Node) -> Node:
_resolve_container_value(cfg, i)

return cfg


def select_value(
cfg: Container,
key: str,
*,
default: Any = _DEFAULT_MARKER_,
throw_on_resolution_failure: bool = True,
throw_on_missing: bool = False,
absolute_key: bool = False,
) -> Any:
ret = select_node(
cfg=cfg,
key=key,
default=default,
throw_on_resolution_failure=throw_on_resolution_failure,
throw_on_missing=throw_on_missing,
absolute_key=absolute_key,
)
if isinstance(ret, Node) and ret._is_missing():
return None

return _get_value(ret)


def select_node(
cfg: Container,
key: str,
*,
default: Any = _DEFAULT_MARKER_,
throw_on_resolution_failure: bool = True,
throw_on_missing: bool = False,
absolute_key: bool = False,
) -> Any:
try:
# for non relative keys, the interpretation can be:
# 1. relative to cfg
# 2. relative to the config root
# This is controlled by the absolute_key flag. By default, such keys are relative to cfg.
if not absolute_key and not key.startswith("."):
key = f".{key}"

cfg, key = cfg._resolve_key_and_root(key)
_root, _last_key, value = cfg._select_impl(
key,
throw_on_missing=throw_on_missing,
throw_on_resolution_failure=throw_on_resolution_failure,
)
except ConfigKeyError:
if default is not _DEFAULT_MARKER_:
return default
else:
raise

if (
default is not _DEFAULT_MARKER_
and _root is not None
and _last_key is not None
and _last_key not in _root
):
return default

if value is not None and value._is_missing():
return None
return value
46 changes: 9 additions & 37 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from ._utils import (
_DEFAULT_MARKER_,
_ensure_container,
_get_value,
_is_none,
format_and_raise,
get_dict_key_value_types,
Expand All @@ -53,7 +52,6 @@
from .base import Container, Node, SCMode
from .basecontainer import BaseContainer
from .errors import (
ConfigKeyError,
MissingMandatoryValue,
OmegaConfBaseException,
UnsupportedInterpolationType,
Expand Down Expand Up @@ -653,7 +651,6 @@ def select(
default: Any = _DEFAULT_MARKER_,
throw_on_resolution_failure: bool = True,
throw_on_missing: bool = False,
absolute_key: bool = False,
) -> Any:
"""
:param cfg: Config node to select from
Expand All @@ -663,43 +660,18 @@ def select(
resolution error occurs, otherwise return None
:param throw_on_missing: Raise an exception if an attempt to select a missing key (with the value '???')
is made, otherwise return None
:param absolute_key: True to treat non-relative keys as relative to the config root
False (default) to treat non-relative keys as relative to cfg
:return: selected value or None if not found.
"""
try:
try:
# for non relative keys, the interpretation can be:
# 1. relative to cfg
# 2. relative to the config root
# This is controlled by the absolute_key flag. By default, such keys are relative to cfg.
if not absolute_key and not key.startswith("."):
key = f".{key}"

cfg, key = cfg._resolve_key_and_root(key)
_root, _last_key, value = cfg._select_impl(
key,
throw_on_missing=throw_on_missing,
throw_on_resolution_failure=throw_on_resolution_failure,
)
except ConfigKeyError:
if default is not _DEFAULT_MARKER_:
return default
else:
raise

if (
default is not _DEFAULT_MARKER_
and _root is not None
and _last_key is not None
and _last_key not in _root
):
return default
from ._impl import select_value

if value is not None and value._is_missing():
return None

return _get_value(value)
try:
return select_value(
cfg=cfg,
key=key,
default=default,
throw_on_resolution_failure=throw_on_resolution_failure,
throw_on_missing=throw_on_missing,
)
except Exception as e:
format_and_raise(node=cfg, key=key, value=None, cause=e, msg=str(e))

Expand Down
4 changes: 2 additions & 2 deletions omegaconf/resolvers/oc/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ def _get_and_validate_dict_input(
parent: BaseContainer,
resolver_name: str,
) -> DictConfig:
from omegaconf import OmegaConf
from omegaconf._impl import select_value

if not isinstance(key, str):
raise TypeError(
f"`{resolver_name}` requires a string as input, but obtained `{key}` "
f"of type: {type(key).__name__}"
)

in_dict = OmegaConf.select(
in_dict = select_value(
parent,
key,
throw_on_missing=True,
Expand Down
9 changes: 5 additions & 4 deletions tests/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pytest import mark, param, raises

from omegaconf import MissingMandatoryValue, OmegaConf
from omegaconf._impl import select_value
from omegaconf._utils import _ensure_container
from omegaconf.errors import ConfigKeyError, InterpolationKeyError

Expand Down Expand Up @@ -265,8 +266,8 @@ def test_select_from_nested_node_with_a_relative_key(
) -> None:
cfg = OmegaConf.create(inp)
# select returns the same result when a key is relative independent of absolute_key flag.
assert OmegaConf.select(cfg.a, key, absolute_key=False) == expected
assert OmegaConf.select(cfg.a, key, absolute_key=True) == expected
assert select_value(cfg.a, key, absolute_key=False) == expected
assert select_value(cfg.a, key, absolute_key=True) == expected

@mark.parametrize(
("key", "expected"),
Expand All @@ -282,7 +283,7 @@ def test_select_from_nested_node_relative_key_interpretation(
self, key: str, expected: Any
) -> None:
cfg = OmegaConf.create(inp)
assert OmegaConf.select(cfg.a, key, absolute_key=False) == expected
assert select_value(cfg.a, key, absolute_key=False) == expected

@mark.parametrize(
("key", "expected"),
Expand All @@ -300,4 +301,4 @@ def test_select_from_nested_node_absolute_key_interpretation(
self, key: str, expected: Any
) -> None:
cfg = OmegaConf.create(inp)
assert OmegaConf.select(cfg.a, key, absolute_key=True) == expected
assert select_value(cfg.a, key, absolute_key=True) == expected

0 comments on commit 0f6ad91

Please sign in to comment.