Skip to content

Commit

Permalink
Support optional _parent_ and _root_ parameters in custom resolvers (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
omry authored Mar 13, 2021
1 parent 3bf2c59 commit 0504958
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 6 deletions.
38 changes: 35 additions & 3 deletions docs/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -489,14 +489,13 @@ You can take advantage of nested interpolations to perform custom operations ove

.. doctest::

>>> OmegaConf.register_new_resolver("plus", lambda x, y: x + y)
>>> OmegaConf.register_new_resolver("sum", lambda x, y: x + y)
>>> c = OmegaConf.create({"a": 1,
... "b": 2,
... "a_plus_b": "${plus:${a},${b}}"})
... "a_plus_b": "${sum:${a},${b}}"})
>>> c.a_plus_b
3


More advanced resolver naming features include the ability to prefix a resolver name with a
namespace, and to use interpolations in the name itself. The following example demonstrates both:

Expand Down Expand Up @@ -531,6 +530,39 @@ inputs we always return the same value. This behavior may be disabled by setting
>>> assert c.uncached != c.uncached



Custom interpolations can also receive the following special parameters:

- ``_parent_`` : the parent node of an interpolation.
- ``_root_``: The config root.

This can be achieved by adding the special parameters to the resolver signature.
Note that special parameters must be defined as named keywords (after the `*`):

In this example, we use ``_parent_`` to implement a sum function that defaults to 0 if the node does not exist.
(In contrast to the sum we defined earlier where accessing an invalid key, e.g. ``"a_plus_z": ${sum:${a}, ${z}}`` will result in an error).

.. doctest::

>>> def sum2(a, b, *, _parent_):
... return _parent_.get(a, 0) + _parent_.get(b, 0)
>>> OmegaConf.register_new_resolver("sum2", sum2, use_cache=False)
>>> cfg = OmegaConf.create(
... {
... "node": {
... "a": 1,
... "b": 2,
... "a_plus_b": "${sum2:a,b}",
... "a_plus_z": "${sum2:a,z}",
... },
... }
... )
>>> cfg.node.a_plus_b
3
>>> cfg.node.a_plus_z
1


Merging configurations
----------------------
Merging configurations enables the creation of reusable configuration files for each logical component
Expand Down
1 change: 1 addition & 0 deletions news/266.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Custom resolvers can now access the parent and the root config nodes
2 changes: 1 addition & 1 deletion omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ def _evaluate_custom_resolver(
resolver = OmegaConf.get_resolver(inter_type)
if resolver is not None:
root_node = self._get_root()
return resolver(root_node, inter_args, inter_args_str)
return resolver(root_node, self, inter_args, inter_args_str)
else:
raise UnsupportedInterpolationType(
f"Unsupported interpolation type {inter_type}"
Expand Down
29 changes: 27 additions & 2 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""OmegaConf module"""
import copy
import inspect
import io
import os
import pathlib
Expand Down Expand Up @@ -431,6 +432,7 @@ def legacy_register_resolver(name: str, resolver: Resolver) -> None:

def resolver_wrapper(
config: BaseContainer,
node: BaseContainer,
args: Tuple[Any, ...],
args_str: Tuple[str, ...],
) -> Any:
Expand Down Expand Up @@ -484,8 +486,22 @@ def register_new_resolver(
name not in BaseContainer._resolvers
), "resolver {} is already registered".format(name)

sig = inspect.signature(resolver)

def _should_pass(special: str) -> bool:
ret = special in sig.parameters
if ret and use_cache:
raise ValueError(
f"use_cache=True is incompatible with functions that receive the {special}"
)
return ret

pass_parent = _should_pass("_parent_")
pass_root = _should_pass("_root_")

def resolver_wrapper(
config: BaseContainer,
parent: Container,
args: Tuple[Any, ...],
args_str: Tuple[str, ...],
) -> Any:
Expand All @@ -498,7 +514,14 @@ def resolver_wrapper(
pass

# Call resolver.
ret = resolver(*args)
kwargs = {}
if pass_parent:
kwargs["_parent_"] = parent
if pass_root:
kwargs["_root_"] = config

ret = resolver(*args, **kwargs)

if use_cache:
cache[hashable_key] = ret
return ret
Expand All @@ -509,7 +532,9 @@ def resolver_wrapper(
@staticmethod
def get_resolver(
name: str,
) -> Optional[Callable[[Container, Tuple[Any, ...], Tuple[str, ...]], Any]]:
) -> Optional[
Callable[[Container, Container, Tuple[Any, ...], Tuple[str, ...]], Any]
]:
# noinspection PyProtectedMember
return (
BaseContainer._resolvers[name] if name in BaseContainer._resolvers else None
Expand Down
83 changes: 83 additions & 0 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,3 +987,86 @@ def test_resolver_output_list_to_listconfig(
assert isinstance(c.x, ListConfig)
assert c.x == expected
assert dereference(c, "x")._get_flag("readonly")


def test_register_cached_resolver_with_keyword_unsupported() -> None:
with pytest.raises(ValueError):
OmegaConf.register_new_resolver("root", lambda _root_: None, use_cache=True)
with pytest.raises(ValueError):
OmegaConf.register_new_resolver("parent", lambda _parent_: None, use_cache=True)


def test_resolver_with_parent(restore_resolvers: Any) -> None:
OmegaConf.register_new_resolver(
"parent", lambda _parent_: _parent_, use_cache=False
)

cfg = OmegaConf.create(
{
"a": 10,
"b": {
"c": 20,
"parent": "${parent:}",
},
"parent": "${parent:}",
}
)

assert cfg.parent is cfg
assert cfg.b.parent is cfg.b


def test_resolver_with_root(restore_resolvers: Any) -> None:
OmegaConf.register_new_resolver("root", lambda _root_: _root_, use_cache=False)
cfg = OmegaConf.create(
{
"a": 10,
"b": {
"c": 20,
"root": "${root:}",
},
"root": "${root:}",
}
)

assert cfg.root is cfg
assert cfg.b.root is cfg


def test_resolver_with_root_and_parent(restore_resolvers: Any) -> None:
OmegaConf.register_new_resolver(
"both", lambda _root_, _parent_: _root_.add + _parent_.add, use_cache=False
)

cfg = OmegaConf.create(
{
"add": 10,
"b": {
"add": 20,
"both": "${both:}",
},
"both": "${both:}",
}
)
assert cfg.both == 20
assert cfg.b.both == 30


def test_resolver_with_parent_and_default_value(restore_resolvers: Any) -> None:
def parent_and_default(default: int = 10, *, _parent_: Any) -> Any:
return _parent_.add + default

OmegaConf.register_new_resolver(
"parent_and_default", parent_and_default, use_cache=False
)

cfg = OmegaConf.create(
{
"add": 10,
"no_param": "${parent_and_default:}",
"param": "${parent_and_default:20}",
}
)

assert cfg.no_param == 20
assert cfg.param == 30

0 comments on commit 0504958

Please sign in to comment.