From 3d778073c3355d9063e96807aa210ebad7328008 Mon Sep 17 00:00:00 2001 From: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Date: Wed, 24 Mar 2021 15:16:30 -0400 Subject: [PATCH] Add new resolvers `oc.dict.keys` and `oc.dict.values` Fixes #643 --- docs/source/usage.rst | 28 ++++++ news/643.feature | 2 + omegaconf/built_in_resolvers.py | 10 ++- omegaconf/omegaconf.py | 9 +- .../built_in_resolvers/test_dict.py | 87 +++++++++++++++++++ 5 files changed, 132 insertions(+), 4 deletions(-) create mode 100644 news/643.feature create mode 100644 tests/interpolation/built_in_resolvers/test_dict.py diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 0e56495e0..da9b16bb2 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -466,6 +466,34 @@ This can be useful for instance to parse environment variables: type: int, value: 3308 +Extracting lists of keys / values from a dictionary +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Some config options that are stored as a ``DictConfig`` may sometimes be easier to manipulate as lists, +when we care only about the keys or the associated values. + +The resolvers ``oc.dict.keys`` and ``oc.dict.values`` simplify such operations by extracting respectively +the list of keys and values from ``dict``-like objects like ``DictConfig``: + +.. doctest:: + + >>> cfg = OmegaConf.create( + ... { + ... "machines": { + ... "node007": "10.0.0.7", + ... "node012": "10.0.0.3", + ... "node075": "10.0.1.8", + ... }, + ... "nodes": "${oc.dict.keys:${machines}}", + ... "ips": "${oc.dict.values:${machines}}", + ... } + ... ) + >>> show(cfg.nodes) + type: ListConfig, value: ['node007', 'node012', 'node075'] + >>> show(cfg.ips) + type: ListConfig, value: ['10.0.0.7', '10.0.0.3', '10.0.1.8'] + + Custom interpolations ^^^^^^^^^^^^^^^^^^^^^ diff --git a/news/643.feature b/news/643.feature new file mode 100644 index 000000000..78143216d --- /dev/null +++ b/news/643.feature @@ -0,0 +1,2 @@ +New resolvers `oc.dict.keys` and `oc.dict.values` allow extracting the lists of keys and values of a DictConfig + diff --git a/omegaconf/built_in_resolvers.py b/omegaconf/built_in_resolvers.py index 373c01a4d..f253c669c 100644 --- a/omegaconf/built_in_resolvers.py +++ b/omegaconf/built_in_resolvers.py @@ -1,6 +1,6 @@ import os import warnings -from typing import Any, Optional +from typing import Any, Dict, List, Optional from ._utils import _DEFAULT_MARKER_, _get_value, decode_primitive from .base import Container @@ -28,6 +28,14 @@ def decode(expr: Optional[str], _parent_: Container) -> Any: return _get_value(val) +def dict_keys(in_dict: Dict[Any, Any]) -> List[Any]: + return list(in_dict.keys()) + + +def dict_values(in_dict: Dict[Any, Any]) -> List[Any]: + return list(in_dict.values()) + + def env(key: str, default: Any = _DEFAULT_MARKER_) -> Optional[str]: """ :param key: Environment variable key diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index 731da2f68..b2390c193 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -93,11 +93,14 @@ def SI(interpolation: str) -> Any: def register_default_resolvers() -> None: - from .built_in_resolvers import decode, env, legacy_env + from .built_in_resolvers import decode, dict_keys, dict_values, env, legacy_env + + OmegaConf.register_new_resolver("oc.decode", decode) + OmegaConf.register_new_resolver("oc.dict.keys", dict_keys) + OmegaConf.register_new_resolver("oc.dict.values", dict_values) + OmegaConf.register_new_resolver("oc.env", env) OmegaConf.legacy_register_resolver("env", legacy_env) - OmegaConf.register_new_resolver("oc.env", env, use_cache=False) - OmegaConf.register_new_resolver("oc.decode", decode, use_cache=False) class OmegaConf: diff --git a/tests/interpolation/built_in_resolvers/test_dict.py b/tests/interpolation/built_in_resolvers/test_dict.py new file mode 100644 index 000000000..b01a0f117 --- /dev/null +++ b/tests/interpolation/built_in_resolvers/test_dict.py @@ -0,0 +1,87 @@ +from typing import Any + +from pytest import mark, param + +from omegaconf import OmegaConf + + +@mark.parametrize( + ("cfg", "key", "expected"), + [ + param( + {"foo": "${oc.dict.keys:{a: 0, b: 1}}"}, + "foo", + OmegaConf.create(["a", "b"]), + id="dict", + ), + param( + {"foo": "${oc.dict.keys:${bar}}", "bar": {"a": 0, "b": 1}}, + "foo", + OmegaConf.create(["a", "b"]), + id="dictconfig", + ), + param( + {"foo": "${sum:${oc.dict.keys:{1: one, 2: two}}}"}, + "foo", + 3, + id="nested", + ), + ], +) +def test_dict_keys(restore_resolvers: Any, cfg: Any, key: Any, expected: Any) -> None: + OmegaConf.register_new_resolver("sum", lambda x: sum(x)) + + cfg = OmegaConf.create(cfg) + val = cfg[key] + assert val == expected + assert type(val) is type(expected) + + +@mark.parametrize( + ("cfg", "key", "expected"), + [ + param( + {"foo": "${oc.dict.values:{a: 0, b: 1}}"}, + "foo", + OmegaConf.create([0, 1]), + id="dict", + ), + param( + {"foo": "${oc.dict.values:${bar}}", "bar": {"a": 0, "b": 1}}, + "foo", + OmegaConf.create([0, 1]), + id="dictconfig", + ), + param( + {"foo": "${sum:${oc.dict.values:{one: 1, two: 2}}}"}, + "foo", + 3, + id="nested", + ), + param( + { + "foo": "${oc.dict.values:${bar}}", + "bar": {"x": {"x0": 0, "x1": 1}, "y": {"y0": 0}}, + }, + "foo", + OmegaConf.create([{"x0": 0, "x1": 1}, {"y0": 0}]), + id="convert_node_to_list", + ), + param( + { + "foo": "${oc.dict.values:{key: ${val_ref}}}", + "val_ref": "value", + }, + "foo", + OmegaConf.create(["value"]), + id="dict_with_interpolated_value", + ), + ], +) +def test_dict_values(restore_resolvers: Any, cfg: Any, key: Any, expected: Any) -> None: + OmegaConf.register_new_resolver("sum", lambda x: sum(x)) + + cfg = OmegaConf.create(cfg) + val = cfg[key] + assert val == expected + assert type(val) is type(expected)