diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index bad25b392..c2dc08137 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -679,6 +679,7 @@ def select( ) -> Any: try: try: + cfg, key = cfg._resolve_key_and_root(key) _root, _last_key, value = cfg._select_impl( key, throw_on_missing=throw_on_missing, diff --git a/tests/interpolation/test_interpolation.py b/tests/interpolation/test_interpolation.py index 29502b04b..bc4a1eb29 100644 --- a/tests/interpolation/test_interpolation.py +++ b/tests/interpolation/test_interpolation.py @@ -3,7 +3,6 @@ from textwrap import dedent from typing import Any, Tuple -from _pytest.python_api import RaisesContext from pytest import mark, param, raises from omegaconf import ( @@ -32,65 +31,6 @@ from tests.interpolation import dereference_node -@mark.parametrize( - "cfg,key,expected", - [ - param({"a": "${b}", "b": 10}, "a", 10, id="simple"), - param( - {"a": "${x}"}, - "a", - raises(InterpolationKeyError), - id="not_found", - ), - param( - {"a": "${x.y}"}, - "a", - raises(InterpolationKeyError), - id="not_found", - ), - param({"a": "foo_${b}", "b": "bar"}, "a", "foo_bar", id="str_inter"), - param( - {"a": "${x}_${y}", "x": "foo", "y": "bar"}, - "a", - "foo_bar", - id="multi_str_inter", - ), - param({"a": "foo_${b.c}", "b": {"c": 10}}, "a", "foo_10", id="str_deep_inter"), - param({"a": 10, "b": [1, "${a}"]}, "b.1", 10, id="from_list"), - param({"a": "${b}", "b": {"c": 10}}, "a", {"c": 10}, id="dict_val"), - param({"a": "${b}", "b": [1, 2]}, "a", [1, 2], id="list_val"), - param({"a": "${b.1}", "b": [1, 2]}, "a", 2, id="list_index"), - param({"a": "X_${b}", "b": [1, 2]}, "a", "X_[1, 2]", id="liststr"), - param({"a": "X_${b}", "b": {"c": 1}}, "a", "X_{'c': 1}", id="dict_str"), - param({"a": "${b}", "b": "${c}", "c": 10}, "a", 10, id="two_steps"), - param({"bar": 10, "foo": ["${bar}"]}, "foo.0", 10, id="inter_in_list"), - param({"foo": None, "bar": "${foo}"}, "bar", None, id="none"), - param({"list": ["bar"], "foo": "${list.0}"}, "foo", "bar", id="list"), - param( - {"user@domain": 10, "foo": "${user@domain}"}, "foo", 10, id="user@domain" - ), - # relative interpolations - param({"a": "${.b}", "b": 10}, "a", 10, id="relative"), - param({"a": {"z": "${.b}", "b": 10}}, "a.z", 10, id="relative"), - param({"a": {"z": "${..b}"}, "b": 10}, "a.z", 10, id="relative"), - param({"a": {"z": "${..a.b}", "b": 10}}, "a.z", 10, id="relative"), - param( - {"a": "${..b}", "b": 10}, - "a", - raises(InterpolationKeyError), - id="relative", - ), - ], -) -def test_interpolation(cfg: Any, key: str, expected: Any) -> None: - cfg = _ensure_container(cfg) - if isinstance(expected, RaisesContext): - with expected: - OmegaConf.select(cfg, key) - else: - assert OmegaConf.select(cfg, key) == expected - - def test_interpolation_with_missing() -> None: cfg = OmegaConf.create( { diff --git a/tests/test_select.py b/tests/test_select.py index 0b1bd49d2..11c093156 100644 --- a/tests/test_select.py +++ b/tests/test_select.py @@ -41,6 +41,9 @@ class TestSelect: param({"a": {"b": 1}, "c": "one=${a.b}"}, "c", "one=1", id="inter"), param({"a": {"b": "one=${n}"}, "n": 1}, "a.b", "one=1", id="inter"), param({"a": {"b": "one=${func:1}"}}, "a.b", "one=_1_", id="resolver"), + # relative selection + param({"a": {"b": {"c": 10}}}, ".a", {"b": {"c": 10}}, id="relative"), + param({"a": {"b": {"c": 10}}}, ".a.b", {"c": 10}, id="relative"), ], ) def test_select( @@ -189,3 +192,82 @@ def test_select_deprecated(self, struct: Optional[bool]) -> None: ), ): cfg.select("foo") + + def test_select_relative_from_nested_node(self, struct: Optional[bool]) -> None: + cfg = OmegaConf.create( + {"a": {"b": {"c": 10}}, "z": 10}, + ) + OmegaConf.set_struct(cfg, struct) + assert OmegaConf.select(cfg.a, ".") == {"b": {"c": 10}} + assert OmegaConf.select(cfg.a, "..") == {"a": {"b": {"c": 10}}, "z": 10} + assert OmegaConf.select(cfg.a, "..a") == {"b": {"c": 10}} + assert OmegaConf.select(cfg.a, "..z") == 10 + + +@mark.parametrize( + "cfg,key,expected", + [ + param({"a": "${b}", "b": 10}, "a", 10, id="simple"), + param( + {"a": "${x}"}, + "a", + raises(InterpolationKeyError), + id="not_found", + ), + param( + {"a": "${x.y}"}, + "a", + raises(InterpolationKeyError), + id="not_found", + ), + param({"a": "foo_${b}", "b": "bar"}, "a", "foo_bar", id="str_inter"), + param( + {"a": "${x}_${y}", "x": "foo", "y": "bar"}, + "a", + "foo_bar", + id="multi_str_inter", + ), + param({"a": "foo_${b.c}", "b": {"c": 10}}, "a", "foo_10", id="str_deep_inter"), + param({"a": 10, "b": [1, "${a}"]}, "b.1", 10, id="from_list"), + param({"a": "${b}", "b": {"c": 10}}, "a", {"c": 10}, id="dict_val"), + param({"a": "${b}", "b": [1, 2]}, "a", [1, 2], id="list_val"), + param({"a": "${b.1}", "b": [1, 2]}, "a", 2, id="list_index"), + param({"a": "X_${b}", "b": [1, 2]}, "a", "X_[1, 2]", id="liststr"), + param({"a": "X_${b}", "b": {"c": 1}}, "a", "X_{'c': 1}", id="dict_str"), + param({"a": "${b}", "b": "${c}", "c": 10}, "a", 10, id="two_steps"), + param({"bar": 10, "foo": ["${bar}"]}, "foo.0", 10, id="inter_in_list"), + param({"foo": None, "bar": "${foo}"}, "bar", None, id="none"), + param({"list": ["bar"], "foo": "${list.0}"}, "foo", "bar", id="list"), + param( + {"user@domain": 10, "foo": "${user@domain}"}, "foo", 10, id="user@domain" + ), + # relative interpolations + param({"a": "${.b}", "b": 10}, "a", 10, id="relative"), + param({"a": {"z": "${.b}", "b": 10}}, "a.z", 10, id="relative"), + param({"a": {"z": "${..b}"}, "b": 10}, "a.z", 10, id="relative"), + param({"a": {"z": "${..a.b}", "b": 10}}, "a.z", 10, id="relative"), + param( + {"a": "${..b}", "b": 10}, + "a", + raises(InterpolationKeyError), + id="relative", + ), + ], +) +def test_select_resolves_interpolation(cfg: Any, key: str, expected: Any) -> None: + cfg = _ensure_container(cfg) + if isinstance(expected, RaisesContext): + with expected: + OmegaConf.select(cfg, key) + else: + assert OmegaConf.select(cfg, key) == expected + + +def test_select_relative_from_nested_node() -> None: + cfg = OmegaConf.create( + {"a": {"b": {"c": 10}}, "z": 10}, + ) + assert OmegaConf.select(cfg.a, ".") == {"b": {"c": 10}} + assert OmegaConf.select(cfg.a, "..") == {"a": {"b": {"c": 10}}, "z": 10} + assert OmegaConf.select(cfg.a, "..a") == {"b": {"c": 10}} + assert OmegaConf.select(cfg.a, "..z") == 10