diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index 1f6f76713..662fb6122 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -676,9 +676,29 @@ def select( default: Any = _DEFAULT_MARKER_, throw_on_resolution_failure: bool = True, throw_on_missing: bool = False, + absolute_key: bool = False, ) -> Any: + """ + :param cfg: Config node to select from + :param key: Key to select + :param default: Default value to return if key is not found + :param throw_on_resolution_failure: Raise an exception if an interpolation + resolution error occurs, otherwise return None + :param throw_on_missing: Raise an exception if an attempt to select a missing key (with the value '???') + is made, otherwise return None + :param absolute_key: True to treat non-relative keys as relative to the config root + False (default) to treat non-relative keys as relative to cfg + :return: selected value or None if not found. + """ try: try: + # for non relative keys, the interpretation can be: + # 1. relative to cfg + # 2. relative to the config root + # This is controlled by the absolute_key flag. By default, such keys are relative to cfg. + if not absolute_key and not key.startswith("."): + key = f".{key}" + cfg, key = cfg._resolve_key_and_root(key) _root, _last_key, value = cfg._select_impl( key, diff --git a/tests/test_select.py b/tests/test_select.py index 11c093156..f3268e8ab 100644 --- a/tests/test_select.py +++ b/tests/test_select.py @@ -193,16 +193,6 @@ def test_select_deprecated(self, struct: Optional[bool]) -> None: ): cfg.select("foo") - def test_select_relative_from_nested_node(self, struct: Optional[bool]) -> None: - cfg = OmegaConf.create( - {"a": {"b": {"c": 10}}, "z": 10}, - ) - OmegaConf.set_struct(cfg, struct) - assert OmegaConf.select(cfg.a, ".") == {"b": {"c": 10}} - assert OmegaConf.select(cfg.a, "..") == {"a": {"b": {"c": 10}}, "z": 10} - assert OmegaConf.select(cfg.a, "..a") == {"b": {"c": 10}} - assert OmegaConf.select(cfg.a, "..z") == 10 - @mark.parametrize( "cfg,key,expected", @@ -263,11 +253,62 @@ def test_select_resolves_interpolation(cfg: Any, key: str, expected: Any) -> Non assert OmegaConf.select(cfg, key) == expected -def test_select_relative_from_nested_node() -> None: - cfg = OmegaConf.create( - {"a": {"b": {"c": 10}}, "z": 10}, +inp: Any = {"a": {"b": {"c": 10}}, "z": 10} + + +class TestSelectFromNestedNode: + @mark.parametrize( + ("key", "expected"), + [ + # all selects are performed on cfg.a: + # relative keys + (".", inp["a"]), + (".b", inp["a"]["b"]), + (".b.c", inp["a"]["b"]["c"]), + ("..", inp), + ("..a", inp["a"]), + ("..a.b", inp["a"]["b"]), + ("..z", inp["z"]), + ], ) - assert OmegaConf.select(cfg.a, ".") == {"b": {"c": 10}} - assert OmegaConf.select(cfg.a, "..") == {"a": {"b": {"c": 10}}, "z": 10} - assert OmegaConf.select(cfg.a, "..a") == {"b": {"c": 10}} - assert OmegaConf.select(cfg.a, "..z") == 10 + def test_select_from_nested_node_with_a_relative_key( + self, key: str, expected: Any + ) -> None: + cfg = OmegaConf.create(inp) + # select returns the same result when a key is relative independent of absolute_key flag. + assert OmegaConf.select(cfg.a, key, absolute_key=False) == expected + assert OmegaConf.select(cfg.a, key, absolute_key=True) == expected + + @mark.parametrize( + ("key", "expected"), + [ + # all selects are performed on cfg.a: + # absolute keys are relative to the calling node + ("", inp["a"]), + ("b", inp["a"]["b"]), + ("b.c", inp["a"]["b"]["c"]), + ], + ) + def test_select_from_nested_node_relative_key_interpretation( + self, key: str, expected: Any + ) -> None: + cfg = OmegaConf.create(inp) + assert OmegaConf.select(cfg.a, key, absolute_key=False) == expected + + @mark.parametrize( + ("key", "expected"), + [ + # all selects are performed on cfg.a: + # absolute keys are relative to the config root + ("", inp), + ("a", inp["a"]), + ("a.b", inp["a"]["b"]), + ("a.b.c", inp["a"]["b"]["c"]), + ("z", inp["z"]), + ], + ) + def test_select_from_nested_node_absolute_key_interpretation( + self, key: str, expected: Any + ) -> None: + cfg = OmegaConf.create(inp) + assert OmegaConf.select(cfg.a, key, absolute_key=True) == expected