diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 8e7f370fc..a4e2e096f 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -489,14 +489,13 @@ You can take advantage of nested interpolations to perform custom operations ove .. doctest:: - >>> OmegaConf.register_new_resolver("plus", lambda x, y: x + y) + >>> OmegaConf.register_new_resolver("sum", lambda x, y: x + y) >>> c = OmegaConf.create({"a": 1, ... "b": 2, - ... "a_plus_b": "${plus:${a},${b}}"}) + ... "a_plus_b": "${sum:${a},${b}}"}) >>> c.a_plus_b 3 - More advanced resolver naming features include the ability to prefix a resolver name with a namespace, and to use interpolations in the name itself. The following example demonstrates both: @@ -531,6 +530,39 @@ inputs we always return the same value. This behavior may be disabled by setting >>> assert c.uncached != c.uncached + +Custom interpolations can also receive the following special parameters: + +- ``_parent_`` : the parent node of an interpolation. +- ``_root_``: The config root. + +This can be achieved by adding the special parameters to the resolver signature. +Note that special parameters must be defined as named keywords (after the `*`): + +In this example, we use ``_parent_`` to implement a sum function that defaults to 0 if the node does not exist. +(In contrast to the sum we defined earlier where accessing an invalid key, e.g. ``"a_plus_z": ${sum:${a}, ${z}}`` will result in an error). + +.. doctest:: + + >>> def sum2(a, b, *, _parent_): + ... return _parent_.get(a, 0) + _parent_.get(b, 0) + >>> OmegaConf.register_new_resolver("sum2", sum2, use_cache=False) + >>> cfg = OmegaConf.create( + ... { + ... "node": { + ... "a": 1, + ... "b": 2, + ... "a_plus_b": "${sum2:a,b}", + ... "a_plus_z": "${sum2:a,z}", + ... }, + ... } + ... ) + >>> cfg.node.a_plus_b + 3 + >>> cfg.node.a_plus_z + 1 + + Merging configurations ---------------------- Merging configurations enables the creation of reusable configuration files for each logical component diff --git a/news/266.feature b/news/266.feature new file mode 100644 index 000000000..f4a6a91f2 --- /dev/null +++ b/news/266.feature @@ -0,0 +1 @@ +Custom resolvers can now access the parent and the root config nodes diff --git a/omegaconf/base.py b/omegaconf/base.py index e4b26c45b..69cd49bb4 100644 --- a/omegaconf/base.py +++ b/omegaconf/base.py @@ -597,7 +597,7 @@ def _evaluate_custom_resolver( resolver = OmegaConf.get_resolver(inter_type) if resolver is not None: root_node = self._get_root() - return resolver(root_node, inter_args, inter_args_str) + return resolver(root_node, self, inter_args, inter_args_str) else: raise UnsupportedInterpolationType( f"Unsupported interpolation type {inter_type}" diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index a46b1b619..62dc658bc 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -1,5 +1,6 @@ """OmegaConf module""" import copy +import inspect import io import os import pathlib @@ -431,6 +432,7 @@ def legacy_register_resolver(name: str, resolver: Resolver) -> None: def resolver_wrapper( config: BaseContainer, + node: BaseContainer, args: Tuple[Any, ...], args_str: Tuple[str, ...], ) -> Any: @@ -484,8 +486,22 @@ def register_new_resolver( name not in BaseContainer._resolvers ), "resolver {} is already registered".format(name) + sig = inspect.signature(resolver) + + def _should_pass(special: str) -> bool: + ret = special in sig.parameters + if ret and use_cache: + raise ValueError( + f"use_cache=True is incompatible with functions that receive the {special}" + ) + return ret + + pass_parent = _should_pass("_parent_") + pass_root = _should_pass("_root_") + def resolver_wrapper( config: BaseContainer, + parent: Container, args: Tuple[Any, ...], args_str: Tuple[str, ...], ) -> Any: @@ -498,7 +514,14 @@ def resolver_wrapper( pass # Call resolver. - ret = resolver(*args) + kwargs = {} + if pass_parent: + kwargs["_parent_"] = parent + if pass_root: + kwargs["_root_"] = config + + ret = resolver(*args, **kwargs) + if use_cache: cache[hashable_key] = ret return ret @@ -509,7 +532,9 @@ def resolver_wrapper( @staticmethod def get_resolver( name: str, - ) -> Optional[Callable[[Container, Tuple[Any, ...], Tuple[str, ...]], Any]]: + ) -> Optional[ + Callable[[Container, Container, Tuple[Any, ...], Tuple[str, ...]], Any] + ]: # noinspection PyProtectedMember return ( BaseContainer._resolvers[name] if name in BaseContainer._resolvers else None diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index 3947fc3b1..da5b56f12 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -987,3 +987,86 @@ def test_resolver_output_list_to_listconfig( assert isinstance(c.x, ListConfig) assert c.x == expected assert dereference(c, "x")._get_flag("readonly") + + +def test_register_cached_resolver_with_keyword_unsupported() -> None: + with pytest.raises(ValueError): + OmegaConf.register_new_resolver("root", lambda _root_: None, use_cache=True) + with pytest.raises(ValueError): + OmegaConf.register_new_resolver("parent", lambda _parent_: None, use_cache=True) + + +def test_resolver_with_parent(restore_resolvers: Any) -> None: + OmegaConf.register_new_resolver( + "parent", lambda _parent_: _parent_, use_cache=False + ) + + cfg = OmegaConf.create( + { + "a": 10, + "b": { + "c": 20, + "parent": "${parent:}", + }, + "parent": "${parent:}", + } + ) + + assert cfg.parent is cfg + assert cfg.b.parent is cfg.b + + +def test_resolver_with_root(restore_resolvers: Any) -> None: + OmegaConf.register_new_resolver("root", lambda _root_: _root_, use_cache=False) + cfg = OmegaConf.create( + { + "a": 10, + "b": { + "c": 20, + "root": "${root:}", + }, + "root": "${root:}", + } + ) + + assert cfg.root is cfg + assert cfg.b.root is cfg + + +def test_resolver_with_root_and_parent(restore_resolvers: Any) -> None: + OmegaConf.register_new_resolver( + "both", lambda _root_, _parent_: _root_.add + _parent_.add, use_cache=False + ) + + cfg = OmegaConf.create( + { + "add": 10, + "b": { + "add": 20, + "both": "${both:}", + }, + "both": "${both:}", + } + ) + assert cfg.both == 20 + assert cfg.b.both == 30 + + +def test_resolver_with_parent_and_default_value(restore_resolvers: Any) -> None: + def parent_and_default(default: int = 10, *, _parent_: Any) -> Any: + return _parent_.add + default + + OmegaConf.register_new_resolver( + "parent_and_default", parent_and_default, use_cache=False + ) + + cfg = OmegaConf.create( + { + "add": 10, + "no_param": "${parent_and_default:}", + "param": "${parent_and_default:20}", + } + ) + + assert cfg.no_param == 20 + assert cfg.param == 30