Skip to content

Commit

Permalink
Add new resolvers oc.dict.keys and oc.dict.values
Browse files Browse the repository at this point in the history
Fixes omry#643
  • Loading branch information
odelalleau committed Apr 10, 2021
1 parent afb7f8d commit 3d77807
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 4 deletions.
28 changes: 28 additions & 0 deletions docs/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,34 @@ This can be useful for instance to parse environment variables:
type: int, value: 3308


Extracting lists of keys / values from a dictionary
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Some config options that are stored as a ``DictConfig`` may sometimes be easier to manipulate as lists,
when we care only about the keys or the associated values.

The resolvers ``oc.dict.keys`` and ``oc.dict.values`` simplify such operations by extracting respectively
the list of keys and values from ``dict``-like objects like ``DictConfig``:

.. doctest::

>>> cfg = OmegaConf.create(
... {
... "machines": {
... "node007": "10.0.0.7",
... "node012": "10.0.0.3",
... "node075": "10.0.1.8",
... },
... "nodes": "${oc.dict.keys:${machines}}",
... "ips": "${oc.dict.values:${machines}}",
... }
... )
>>> show(cfg.nodes)
type: ListConfig, value: ['node007', 'node012', 'node075']
>>> show(cfg.ips)
type: ListConfig, value: ['10.0.0.7', '10.0.0.3', '10.0.1.8']


Custom interpolations
^^^^^^^^^^^^^^^^^^^^^

Expand Down
2 changes: 2 additions & 0 deletions news/643.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
New resolvers `oc.dict.keys` and `oc.dict.values` allow extracting the lists of keys and values of a DictConfig

10 changes: 9 additions & 1 deletion omegaconf/built_in_resolvers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import warnings
from typing import Any, Optional
from typing import Any, Dict, List, Optional

from ._utils import _DEFAULT_MARKER_, _get_value, decode_primitive
from .base import Container
Expand Down Expand Up @@ -28,6 +28,14 @@ def decode(expr: Optional[str], _parent_: Container) -> Any:
return _get_value(val)


def dict_keys(in_dict: Dict[Any, Any]) -> List[Any]:
return list(in_dict.keys())


def dict_values(in_dict: Dict[Any, Any]) -> List[Any]:
return list(in_dict.values())


def env(key: str, default: Any = _DEFAULT_MARKER_) -> Optional[str]:
"""
:param key: Environment variable key
Expand Down
9 changes: 6 additions & 3 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,14 @@ def SI(interpolation: str) -> Any:


def register_default_resolvers() -> None:
from .built_in_resolvers import decode, env, legacy_env
from .built_in_resolvers import decode, dict_keys, dict_values, env, legacy_env

OmegaConf.register_new_resolver("oc.decode", decode)
OmegaConf.register_new_resolver("oc.dict.keys", dict_keys)
OmegaConf.register_new_resolver("oc.dict.values", dict_values)
OmegaConf.register_new_resolver("oc.env", env)

OmegaConf.legacy_register_resolver("env", legacy_env)
OmegaConf.register_new_resolver("oc.env", env, use_cache=False)
OmegaConf.register_new_resolver("oc.decode", decode, use_cache=False)


class OmegaConf:
Expand Down
87 changes: 87 additions & 0 deletions tests/interpolation/built_in_resolvers/test_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from typing import Any

from pytest import mark, param

from omegaconf import OmegaConf


@mark.parametrize(
("cfg", "key", "expected"),
[
param(
{"foo": "${oc.dict.keys:{a: 0, b: 1}}"},
"foo",
OmegaConf.create(["a", "b"]),
id="dict",
),
param(
{"foo": "${oc.dict.keys:${bar}}", "bar": {"a": 0, "b": 1}},
"foo",
OmegaConf.create(["a", "b"]),
id="dictconfig",
),
param(
{"foo": "${sum:${oc.dict.keys:{1: one, 2: two}}}"},
"foo",
3,
id="nested",
),
],
)
def test_dict_keys(restore_resolvers: Any, cfg: Any, key: Any, expected: Any) -> None:
OmegaConf.register_new_resolver("sum", lambda x: sum(x))

cfg = OmegaConf.create(cfg)
val = cfg[key]
assert val == expected
assert type(val) is type(expected)


@mark.parametrize(
("cfg", "key", "expected"),
[
param(
{"foo": "${oc.dict.values:{a: 0, b: 1}}"},
"foo",
OmegaConf.create([0, 1]),
id="dict",
),
param(
{"foo": "${oc.dict.values:${bar}}", "bar": {"a": 0, "b": 1}},
"foo",
OmegaConf.create([0, 1]),
id="dictconfig",
),
param(
{"foo": "${sum:${oc.dict.values:{one: 1, two: 2}}}"},
"foo",
3,
id="nested",
),
param(
{
"foo": "${oc.dict.values:${bar}}",
"bar": {"x": {"x0": 0, "x1": 1}, "y": {"y0": 0}},
},
"foo",
OmegaConf.create([{"x0": 0, "x1": 1}, {"y0": 0}]),
id="convert_node_to_list",
),
param(
{
"foo": "${oc.dict.values:{key: ${val_ref}}}",
"val_ref": "value",
},
"foo",
OmegaConf.create(["value"]),
id="dict_with_interpolated_value",
),
],
)
def test_dict_values(restore_resolvers: Any, cfg: Any, key: Any, expected: Any) -> None:
OmegaConf.register_new_resolver("sum", lambda x: sum(x))

cfg = OmegaConf.create(cfg)
val = cfg[key]
assert val == expected
assert type(val) is type(expected)

0 comments on commit 3d77807

Please sign in to comment.