Skip to content

Commit

Permalink
Improvements to oc.dict.{keys,values}
Browse files Browse the repository at this point in the history
* Can now take a string as input for convenience
* The output is always a ListConfig, whose parent is the parent of the
  node being processed
  • Loading branch information
odelalleau committed Mar 24, 2021
1 parent cea35d4 commit 872eb2d
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 10 deletions.
7 changes: 5 additions & 2 deletions docs/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,8 @@ Some config options that are stored as a ``DictConfig`` may sometimes be easier
when we care only about the keys or the associated values.

The resolvers ``oc.dict.keys`` and ``oc.dict.values`` simplify such operations by extracting respectively
the list of keys and values from ``dict``-like objects like ``DictConfig``:
the list of keys and values from ``dict``-like objects like ``DictConfig``.
If a string is given as input, ``OmegaConf.select()`` is used to access the corresponding config node.

.. doctest::

Expand All @@ -476,8 +477,10 @@ the list of keys and values from ``dict``-like objects like ``DictConfig``:
... "node012": "10.0.0.3",
... "node075": "10.0.1.8",
... },
... # Explicit interpolation `${machines}` as input.
... "nodes": "${oc.dict.keys:${machines}}",
... "ips": "${oc.dict.values:${machines}}",
... # Config node name `machines` as input.
... "ips": "${oc.dict.values:machines}",
... }
... )
>>> show(cfg.nodes)
Expand Down
48 changes: 43 additions & 5 deletions omegaconf/built_in_resolvers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import os
import warnings
from typing import Any, Dict, List, Optional

# from collections.abc import Mapping, MutableMapping
from typing import Any, Mapping, Optional, Union

from ._utils import _DEFAULT_MARKER_, _get_value, decode_primitive
from .base import Container
from .basecontainer import BaseContainer
from .errors import ValidationError
from .grammar_parser import parse
from .listconfig import ListConfig
from .omegaconf import OmegaConf


def decode(expr: Optional[str], _parent_: Container) -> Any:
Expand All @@ -28,12 +33,22 @@ def decode(expr: Optional[str], _parent_: Container) -> Any:
return _get_value(val)


def dict_keys(in_dict: Dict[Any, Any]) -> List[Any]:
return list(in_dict.keys())
def dict_keys(
in_dict: Union[str, Mapping[Any, Any]],
_root_: BaseContainer,
_parent_: Container,
) -> ListConfig:
return _dict_impl(
keys_or_values="keys", in_dict=in_dict, _root_=_root_, _parent_=_parent_
)


def dict_values(in_dict: Dict[Any, Any]) -> List[Any]:
return list(in_dict.values())
def dict_values(
in_dict: Union[str, Mapping[Any, Any]], _root_: BaseContainer, _parent_: Container
) -> ListConfig:
return _dict_impl(
keys_or_values="values", in_dict=in_dict, _root_=_root_, _parent_=_parent_
)


def env(key: str, default: Optional[str] = _DEFAULT_MARKER_) -> Optional[str]:
Expand Down Expand Up @@ -69,3 +84,26 @@ def legacy_env(key: str, default: Optional[str] = None) -> Any:
return decode_primitive(default)
else:
raise ValidationError(f"Environment variable '{key}' not found")


def _dict_impl(
keys_or_values: str,
in_dict: Union[str, Mapping[Any, Any]],
_root_: BaseContainer,
_parent_: Container,
) -> ListConfig:
if isinstance(in_dict, str):
# Path to an existing key in the config: use `select()`.
in_dict = OmegaConf.select(_root_, in_dict, throw_on_missing=True)

if not isinstance(in_dict, Mapping):
raise TypeError(
f"`oc.dict.{keys_or_values}` cannot be applied to objects of type: "
f"{type(in_dict).__name__}"
)

dict_method = getattr(in_dict, keys_or_values)
assert isinstance(_parent_, BaseContainer)
ret = OmegaConf.create(list(dict_method()), parent=_parent_)
assert isinstance(ret, ListConfig)
return ret
55 changes: 52 additions & 3 deletions tests/interpolation/built_in_resolvers/test_dict.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Any

from pytest import mark, param
from pytest import mark, param, raises

from omegaconf import OmegaConf
from omegaconf import ListConfig, OmegaConf
from omegaconf.errors import InterpolationResolutionError


@mark.parametrize(
Expand All @@ -18,7 +19,13 @@
{"foo": "${oc.dict.keys:${bar}}", "bar": {"a": 0, "b": 1}},
"foo",
OmegaConf.create(["a", "b"]),
id="dictconfig",
id="dictconfig_interpolation",
),
param(
{"foo": "${oc.dict.keys:bar}", "bar": {"a": 0, "b": 1}},
"foo",
OmegaConf.create(["a", "b"]),
id="dictconfig_select",
),
param(
{"foo": "${sum:${oc.dict.keys:{1: one, 2: two}}}"},
Expand All @@ -36,6 +43,9 @@ def test_dict_keys(restore_resolvers: Any, cfg: Any, key: Any, expected: Any) ->
assert val == expected
assert type(val) is type(expected)

if isinstance(val, ListConfig):
assert val._parent is cfg


@mark.parametrize(
("cfg", "key", "expected"),
Expand All @@ -52,6 +62,12 @@ def test_dict_keys(restore_resolvers: Any, cfg: Any, key: Any, expected: Any) ->
OmegaConf.create([0, 1]),
id="dictconfig",
),
param(
{"foo": "${oc.dict.values:bar}", "bar": {"a": 0, "b": 1}},
"foo",
OmegaConf.create([0, 1]),
id="dictconfig_select",
),
param(
{"foo": "${sum:${oc.dict.values:{one: 1, two: 2}}}"},
"foo",
Expand Down Expand Up @@ -85,3 +101,36 @@ def test_dict_values(restore_resolvers: Any, cfg: Any, key: Any, expected: Any)
val = cfg[key]
assert val == expected
assert type(val) is type(expected)

if isinstance(val, ListConfig):
assert val._parent is cfg


@mark.parametrize(
"cfg",
[
param({"x": "${oc.dict.keys:[]}"}, id="list"),
param({"x": "${oc.dict.keys:${bool}}", "bool": True}, id="bool_interpolation"),
param({"x": "${oc.dict.keys:int}", "int": 0}, id="int_select"),
],
)
def test_dict_keys_invalid_type(cfg: Any) -> None:
cfg = OmegaConf.create(cfg)
with raises(InterpolationResolutionError, match="TypeError"):
cfg.x


@mark.parametrize(
"cfg",
[
param({"x": "${oc.dict.values:[]}"}, id="list"),
param(
{"x": "${oc.dict.values:${bool}}", "bool": True}, id="bool_interpolation"
),
param({"x": "${oc.dict.values:int}", "int": 0}, id="int_select"),
],
)
def test_dict_values_invalid_type(cfg: Any) -> None:
cfg = OmegaConf.create(cfg)
with raises(InterpolationResolutionError, match="TypeError"):
cfg.x

0 comments on commit 872eb2d

Please sign in to comment.