Skip to content

Commit

Permalink
Add new resolver oc.create
Browse files Browse the repository at this point in the history
  • Loading branch information
odelalleau committed Apr 16, 2021
1 parent abc60ae commit f5fb0cc
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 21 deletions.
36 changes: 23 additions & 13 deletions docs/source/custom_resolvers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,19 +60,6 @@ simply use quotes to bypass character limitations in strings.
'Hello, World'


Custom resolvers can return lists or dictionaries, that are automatically converted into DictConfig and ListConfig:

.. doctest::

>>> OmegaConf.register_new_resolver(
... "min_max", lambda *a: {"min": min(a), "max": max(a)}
... )
>>> c = OmegaConf.create({'stats': '${min_max: -1, 3, 2, 5, -10}'})
>>> assert isinstance(c.stats, DictConfig)
>>> c.stats.min, c.stats.max
(-10, 5)


You can take advantage of nested interpolations to perform custom operations over variables:

.. doctest::
Expand Down Expand Up @@ -213,6 +200,29 @@ The following example falls back to default passwords when ``DB_PASSWORD`` is no
>>> show(cfg.database.password3)
type: NoneType, value: None


.. _oc.create:

oc.create
^^^^^^^^^

``oc.create`` may be used for dynamic generation of config nodes
(typically from Python ``dict`` / ``list`` objects or YAML strings, similar to :ref:`OmegaConf.create<creating>`).
The following example combines ``oc.create`` with ``oc.decode`` and ``oc.env`` to generate
a sub-config from an environment variable:

.. doctest::

>>> cfg = OmegaConf.create(
... {
... "model": "${oc.create:${oc.decode:${oc.env:MODEL}}}",
... }
... )
>>> os.environ["MODEL"] = "{name: my_model, layer_size: [100, 200]}"
>>> show(cfg.model.layer_size)
type: ListConfig, value: [100, 200]


.. _oc.deprecated:

oc.deprecated
Expand Down
3 changes: 3 additions & 0 deletions docs/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ Just pip install::

OmegaConf requires Python 3.6 and newer.

.. _creating:

Creating
--------
You can create OmegaConf objects from multiple sources.
Expand Down Expand Up @@ -401,6 +403,7 @@ Built-in resolvers
^^^^^^^^^^^^^^^^^^
OmegaConf comes with a set of built-in custom resolvers:

* :ref:`oc.create`: Dynamically generating config nodes
* :ref:`oc.decode`: Parsing an input string using interpolation grammar
* :ref:`oc.deprecated`: Deprecate a key in your config
* :ref:`oc.env`: Accessing environment variables
Expand Down
1 change: 1 addition & 0 deletions news/645.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The new built-in resolver `oc.create` can be used to dynamically generate config nodes
1 change: 1 addition & 0 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def SI(interpolation: str) -> Any:
def register_default_resolvers() -> None:
from omegaconf.resolvers import env, oc

OmegaConf.register_new_resolver("oc.create", oc.create)
OmegaConf.register_new_resolver("oc.decode", oc.decode)
OmegaConf.register_new_resolver("oc.deprecated", oc.deprecated)
OmegaConf.register_new_resolver("oc.env", oc.env)
Expand Down
14 changes: 14 additions & 0 deletions omegaconf/resolvers/oc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,24 @@

from omegaconf import Container, Node
from omegaconf._utils import _DEFAULT_MARKER_, _get_value
from omegaconf.basecontainer import BaseContainer
from omegaconf.errors import ConfigKeyError
from omegaconf.grammar_parser import parse
from omegaconf.resolvers.oc import dict


def create(obj: Any, _parent_: Container) -> Any:
"""Create a config object from `obj`, similar to `OmegaConf.create`"""
from omegaconf import OmegaConf

assert isinstance(_parent_, BaseContainer)
ret = OmegaConf.create(obj, parent=_parent_)
# Since this node is re-generated on-the-fly, changes would be lost: we mark it
# as read-only to avoid mistakes.
ret._set_flag("readonly", True)
return ret


def env(key: str, default: Any = _DEFAULT_MARKER_) -> Optional[str]:
"""
:param key: Environment variable key
Expand Down Expand Up @@ -85,6 +98,7 @@ def deprecated(


__all__ = [
"create",
"decode",
"deprecated",
"dict",
Expand Down
134 changes: 134 additions & 0 deletions tests/interpolation/built_in_resolvers/test_create_resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from typing import Any, Dict, List

from pytest import mark, param, raises

from omegaconf import DictConfig, ListConfig, OmegaConf, flag_override
from omegaconf.errors import InterpolationResolutionError


@mark.parametrize(
("cfg", "key", "expected"),
[
# Note that since `oc.create` is simply calling `OmegaConf.create`, which is
# already thoroughly tested, we do not do extensive tests here.
param(
{"x": "${oc.create:{a: 0, b: 1}}"},
"x",
OmegaConf.create({"a": 0, "b": 1}),
id="dict",
),
param(
{"x": "${oc.create:[0, 1, 2]}"},
"x",
OmegaConf.create([0, 1, 2]),
id="list",
),
param(
{"x": "${oc.create:{a: 0, b: ${y}}}", "y": 5},
"x",
OmegaConf.create({"a": 0, "b": 5}),
id="dict:interpolated_value",
),
param(
{"x": "${oc.create:[0, 1, ${y}]}", "y": 5},
"x",
OmegaConf.create([0, 1, 5]),
id="list:interpolated_value",
),
param(
{"x": "${oc.create:${y}}", "y": {"a": 0}},
"x",
OmegaConf.create({"a": 0}),
id="dict:interpolated_node",
),
param(
{"x": "${oc.create:${y}}", "y": [0, 1]},
"x",
OmegaConf.create([0, 1]),
id="list:interpolated_node",
),
],
)
def test_create(cfg: Any, key: str, expected: Any) -> None:
cfg = OmegaConf.create(cfg)
val = cfg[key]
assert val == expected
assert type(val) is type(expected)
assert val._get_flag("readonly")


def test_create_error() -> None:
cfg = OmegaConf.create({"x": "${oc.create:0}"})
with raises(InterpolationResolutionError, match="ValidationError"):
cfg.x


def test_write_into_output() -> None:
cfg = OmegaConf.create(
{
"x": "${oc.create:${y}}",
"y": {
"a": 0,
"b": {"c": 1},
},
}
)
x = cfg.x
assert x._get_flag("readonly")

# "Force-write" into the node generated by `oc.create`.
with flag_override(x, "readonly", False):
x.a = 1
x.b.c = 2

# The node that we force-wrote into should be modified.
assert x.a == 1
assert x.b.c == 2

# The interpolated node should not be modified.
assert cfg.y.a == 0
assert cfg.y.b.c == 1

# Re-accessing the node "forgets" the changes.
assert cfg.x.a == 0
assert cfg.x.b.c == 1


@mark.parametrize(
("cfg", "expected"),
[
({"a": 0, "b": 1}, {"a": 0, "b": 1}),
({"a": "${y}"}, {"a": -1}),
({"a": 0, "b": "${x.a}"}, {"a": 0, "b": 0}),
({"a": 0, "b": "${.a}"}, {"a": 0, "b": 0}),
({"a": "${..y}"}, {"a": -1}),
],
)
def test_resolver_output_dict_to_dictconfig(
restore_resolvers: Any, cfg: Dict[str, Any], expected: Dict[str, Any]
) -> None:
OmegaConf.register_new_resolver("dict", lambda: cfg)
c = OmegaConf.create({"x": "${oc.create:${dict:}}", "y": -1})
assert isinstance(c.x, DictConfig)
assert c.x == expected
assert c.x._parent is c


@mark.parametrize(
("cfg", "expected"),
[
([0, 1], [0, 1]),
(["${y}"], [-1]),
([0, "${x.0}"], [0, 0]),
([0, "${.0}"], [0, 0]),
(["${..y}"], [-1]),
],
)
def test_resolver_output_list_to_listconfig(
restore_resolvers: Any, cfg: List[Any], expected: List[Any]
) -> None:
OmegaConf.register_new_resolver("list", lambda: cfg)
c = OmegaConf.create({"x": "${oc.create:${list:}}", "y": -1})
assert isinstance(c.x, ListConfig)
assert c.x == expected
assert c.x._parent is c
15 changes: 7 additions & 8 deletions tests/interpolation/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,14 +278,14 @@ def test_none_value_in_quoted_string(restore_resolvers: Any) -> None:
id="convert_str_to_int",
),
param(
MissingList(list=SI("${identity:[a, b, c]}")), # BROKEN
MissingList(list=SI("${oc.create:[a, b, c]}")),
"list",
["a", "b", "c"],
ListConfig,
id="list_str",
),
param(
MissingDict(dict=SI("${identity:{key1: val1, key2: val2}}")), # BROKEN
MissingDict(dict=SI("${oc.create:{key1: val1, key2: val2}}")),
"dict",
{"key1": "val1", "key2": "val2"},
DictConfig,
Expand Down Expand Up @@ -372,28 +372,28 @@ def test_interpolation_type_validated_error(
("cfg", "key", "expected_value", "expected_node_type"),
[
param(
MissingList(list=SI("${identity:[0, 1, 2]}")),
MissingList(list=SI("${oc.create:[0, 1, 2]}")),
"list",
[0, 1, 2],
ListConfig,
id="list_int_to_str",
),
param(
MissingDict(dict=SI("${identity:{a: 0, b: 1}}")),
MissingDict(dict=SI("${oc.create:{a: 0, b: 1}}")),
"dict",
{"a": 0, "b": 1},
DictConfig,
id="dict_int_to_str",
),
param(
SubscriptedList(list=SI("${identity:[a, b]}")),
SubscriptedList(list=SI("${oc.create:[a, b]}")),
"list",
["a", "b"],
ListConfig,
id="list_type_mismatch",
),
param(
MissingDict(dict=SI("${identity:{0: b, 1: d}}")),
MissingDict(dict=SI("${oc.create:{0: b, 1: d}}")),
"dict",
{0: "b", 1: "d"},
DictConfig,
Expand All @@ -402,7 +402,6 @@ def test_interpolation_type_validated_error(
],
)
def test_interpolation_type_not_validated(
common_resolvers: Any,
cfg: Any,
key: str,
expected_value: Any,
Expand All @@ -415,7 +414,7 @@ def test_interpolation_type_not_validated(

node = cfg._get_node(key)
assert isinstance(node, Node)
# assert isinstance(node._dereference_node(), expected_node_type) BROKEN
assert isinstance(node._dereference_node(), expected_node_type)


def test_type_validation_error_no_throw() -> None:
Expand Down

0 comments on commit f5fb0cc

Please sign in to comment.