From 0f6ad9102d9469cd5ff024437d0714a5c2623b4c Mon Sep 17 00:00:00 2001 From: Omry Yadan Date: Sun, 11 Apr 2021 18:03:37 -0700 Subject: [PATCH] refactor select_value and select_node out of OmegaConf.select --- omegaconf/_impl.py | 69 +++++++++++++++++++++++++++++++++- omegaconf/omegaconf.py | 46 +++++------------------ omegaconf/resolvers/oc/dict.py | 4 +- tests/test_select.py | 9 +++-- 4 files changed, 84 insertions(+), 44 deletions(-) diff --git a/omegaconf/_impl.py b/omegaconf/_impl.py index 61207e0cc..3f9deba49 100644 --- a/omegaconf/_impl.py +++ b/omegaconf/_impl.py @@ -1,7 +1,9 @@ from typing import Any from omegaconf import MISSING, Container, DictConfig, ListConfig, Node, ValueNode -from omegaconf.errors import InterpolationToMissingValueError +from omegaconf.errors import ConfigKeyError, InterpolationToMissingValueError + +from ._utils import _DEFAULT_MARKER_, _get_value def _resolve_container_value(cfg: Container, key: Any) -> None: @@ -42,3 +44,68 @@ def _resolve(cfg: Node) -> Node: _resolve_container_value(cfg, i) return cfg + + +def select_value( + cfg: Container, + key: str, + *, + default: Any = _DEFAULT_MARKER_, + throw_on_resolution_failure: bool = True, + throw_on_missing: bool = False, + absolute_key: bool = False, +) -> Any: + ret = select_node( + cfg=cfg, + key=key, + default=default, + throw_on_resolution_failure=throw_on_resolution_failure, + throw_on_missing=throw_on_missing, + absolute_key=absolute_key, + ) + if isinstance(ret, Node) and ret._is_missing(): + return None + + return _get_value(ret) + + +def select_node( + cfg: Container, + key: str, + *, + default: Any = _DEFAULT_MARKER_, + throw_on_resolution_failure: bool = True, + throw_on_missing: bool = False, + absolute_key: bool = False, +) -> Any: + try: + # for non relative keys, the interpretation can be: + # 1. relative to cfg + # 2. relative to the config root + # This is controlled by the absolute_key flag. By default, such keys are relative to cfg. + if not absolute_key and not key.startswith("."): + key = f".{key}" + + cfg, key = cfg._resolve_key_and_root(key) + _root, _last_key, value = cfg._select_impl( + key, + throw_on_missing=throw_on_missing, + throw_on_resolution_failure=throw_on_resolution_failure, + ) + except ConfigKeyError: + if default is not _DEFAULT_MARKER_: + return default + else: + raise + + if ( + default is not _DEFAULT_MARKER_ + and _root is not None + and _last_key is not None + and _last_key not in _root + ): + return default + + if value is not None and value._is_missing(): + return None + return value diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index eded511fb..e804fe022 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -30,7 +30,6 @@ from ._utils import ( _DEFAULT_MARKER_, _ensure_container, - _get_value, _is_none, format_and_raise, get_dict_key_value_types, @@ -53,7 +52,6 @@ from .base import Container, Node, SCMode from .basecontainer import BaseContainer from .errors import ( - ConfigKeyError, MissingMandatoryValue, OmegaConfBaseException, UnsupportedInterpolationType, @@ -653,7 +651,6 @@ def select( default: Any = _DEFAULT_MARKER_, throw_on_resolution_failure: bool = True, throw_on_missing: bool = False, - absolute_key: bool = False, ) -> Any: """ :param cfg: Config node to select from @@ -663,43 +660,18 @@ def select( resolution error occurs, otherwise return None :param throw_on_missing: Raise an exception if an attempt to select a missing key (with the value '???') is made, otherwise return None - :param absolute_key: True to treat non-relative keys as relative to the config root - False (default) to treat non-relative keys as relative to cfg :return: selected value or None if not found. """ - try: - try: - # for non relative keys, the interpretation can be: - # 1. relative to cfg - # 2. relative to the config root - # This is controlled by the absolute_key flag. By default, such keys are relative to cfg. - if not absolute_key and not key.startswith("."): - key = f".{key}" - - cfg, key = cfg._resolve_key_and_root(key) - _root, _last_key, value = cfg._select_impl( - key, - throw_on_missing=throw_on_missing, - throw_on_resolution_failure=throw_on_resolution_failure, - ) - except ConfigKeyError: - if default is not _DEFAULT_MARKER_: - return default - else: - raise - - if ( - default is not _DEFAULT_MARKER_ - and _root is not None - and _last_key is not None - and _last_key not in _root - ): - return default + from ._impl import select_value - if value is not None and value._is_missing(): - return None - - return _get_value(value) + try: + return select_value( + cfg=cfg, + key=key, + default=default, + throw_on_resolution_failure=throw_on_resolution_failure, + throw_on_missing=throw_on_missing, + ) except Exception as e: format_and_raise(node=cfg, key=key, value=None, cause=e, msg=str(e)) diff --git a/omegaconf/resolvers/oc/dict.py b/omegaconf/resolvers/oc/dict.py index 927082e2f..276e79a30 100644 --- a/omegaconf/resolvers/oc/dict.py +++ b/omegaconf/resolvers/oc/dict.py @@ -53,7 +53,7 @@ def _get_and_validate_dict_input( parent: BaseContainer, resolver_name: str, ) -> DictConfig: - from omegaconf import OmegaConf + from omegaconf._impl import select_value if not isinstance(key, str): raise TypeError( @@ -61,7 +61,7 @@ def _get_and_validate_dict_input( f"of type: {type(key).__name__}" ) - in_dict = OmegaConf.select( + in_dict = select_value( parent, key, throw_on_missing=True, diff --git a/tests/test_select.py b/tests/test_select.py index 51658cdc1..d4caffb12 100644 --- a/tests/test_select.py +++ b/tests/test_select.py @@ -5,6 +5,7 @@ from pytest import mark, param, raises from omegaconf import MissingMandatoryValue, OmegaConf +from omegaconf._impl import select_value from omegaconf._utils import _ensure_container from omegaconf.errors import ConfigKeyError, InterpolationKeyError @@ -265,8 +266,8 @@ def test_select_from_nested_node_with_a_relative_key( ) -> None: cfg = OmegaConf.create(inp) # select returns the same result when a key is relative independent of absolute_key flag. - assert OmegaConf.select(cfg.a, key, absolute_key=False) == expected - assert OmegaConf.select(cfg.a, key, absolute_key=True) == expected + assert select_value(cfg.a, key, absolute_key=False) == expected + assert select_value(cfg.a, key, absolute_key=True) == expected @mark.parametrize( ("key", "expected"), @@ -282,7 +283,7 @@ def test_select_from_nested_node_relative_key_interpretation( self, key: str, expected: Any ) -> None: cfg = OmegaConf.create(inp) - assert OmegaConf.select(cfg.a, key, absolute_key=False) == expected + assert select_value(cfg.a, key, absolute_key=False) == expected @mark.parametrize( ("key", "expected"), @@ -300,4 +301,4 @@ def test_select_from_nested_node_absolute_key_interpretation( self, key: str, expected: Any ) -> None: cfg = OmegaConf.create(inp) - assert OmegaConf.select(cfg.a, key, absolute_key=True) == expected + assert select_value(cfg.a, key, absolute_key=True) == expected