diff --git a/docs/source/custom_resolvers.rst b/docs/source/custom_resolvers.rst index b1ad14c12..a832516b0 100644 --- a/docs/source/custom_resolvers.rst +++ b/docs/source/custom_resolvers.rst @@ -2,6 +2,9 @@ from omegaconf import OmegaConf, DictConfig import os + import pytest + os.environ['USER'] = 'omry' + def show(x): print(f"type: {type(x).__name__}, value: {repr(x)}") @@ -131,14 +134,14 @@ the inputs themselves: Custom interpolations can also receive the following special parameters: -- ``_parent_`` : the parent node of an interpolation. +- ``_parent_``: the parent node of an interpolation. - ``_root_``: The config root. This can be achieved by adding the special parameters to the resolver signature. -Note that special parameters must be defined as named keywords (after the `*`): +Note that special parameters must be defined as named keywords (after the `*`). -In this example, we use ``_parent_`` to implement a sum function that defaults to 0 if the node does not exist. -(In contrast to the sum we defined earlier where accessing an invalid key, e.g. ``"a_plus_z": ${sum:${a}, ${z}}`` will result in an error). +In the example below, we use ``_parent_`` to implement a sum function that defaults to 0 if the node does not exist. +(In contrast to the sum we defined earlier where accessing an invalid key, e.g. ``"a_plus_z": ${sum:${a}, ${z}}`` would result in an error). .. doctest:: @@ -189,6 +192,53 @@ In such a case, the default value is converted to a string using ``str(default)` The following example falls back to default passwords when ``DB_PASSWORD`` is not defined: +.. doctest:: + + >>> cfg = OmegaConf.create( + ... { + ... "database": { + ... "password1": "${oc.env:DB_PASSWORD,password}", + ... "password2": "${oc.env:DB_PASSWORD,12345}", + ... "password3": "${oc.env:DB_PASSWORD,null}", + ... }, + ... } + ... ) + >>> # default is already a string + >>> show(cfg.database.password1) + type: str, value: 'password' + >>> # default is converted to a string automatically + >>> show(cfg.database.password2) + type: str, value: '12345' + >>> # unless it's None + >>> show(cfg.database.password3) + type: NoneType, value: None + +.. _oc.deprecated: + +oc.deprecated +^^^^^^^^^^^^^ +``oc.deprecated`` enables you to deprecate a config node. +It takes two parameters: + +- ``key``: An interpolation key representing the new key you are migrating to. This parameter is required. +- ``message``: A message to use as the warning when the config node is being accessed. The default message is +``'$OLD_KEY' is deprecated. Change your code and config to use '$NEW_KEY'``. + +.. doctest:: + + >>> conf = OmegaConf.create({ + ... "rusty_key": "${oc.deprecated:shiny_key}", + ... "custom_msg": "${oc.deprecated:shiny_key, 'Use $NEW_KEY'}", + ... "shiny_key": 10 + ... }) + >>> # Accessing rusty_key will issue a deprecation warning + >>> # and return the new value automatically + >>> warning = "'rusty_key' is deprecated. Change your" \ + ... " code and config to use 'shiny_key'" + >>> with pytest.warns(UserWarning, match=warning): + ... assert conf.rusty_key == 10 + >>> with pytest.warns(UserWarning, match="Use shiny_key"): + ... assert conf.custom_msg == 10 .. _oc.decode: diff --git a/docs/source/usage.rst b/docs/source/usage.rst index a02c481fe..054bbe14d 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -5,7 +5,6 @@ import sys import tempfile import pickle - os.environ['USER'] = 'omry' # ensures that DB_TIMEOUT is not set in the doc. os.environ.pop('DB_TIMEOUT', None) @@ -385,9 +384,9 @@ Interpolated nodes can be any node in the config, not just leaf nodes: Resolvers ^^^^^^^^^ -You can add additional interpolation types by registering resolvers using ``OmegaConf.register_new_resolver()``. +Add new interpolation types by registering resolvers using ``OmegaConf.register_new_resolver()``. Such resolvers are called when the config node is accessed. -See :doc:`custom_resolvers` for more details, or keep reading for a minimal example. +The minimal example below shows its most basic usage, see :doc:`custom_resolvers` for more details. .. doctest:: @@ -403,7 +402,8 @@ Built-in resolvers OmegaConf comes with a set of built-in custom resolvers: * :ref:`oc.decode`: Parsing an input string using interpolation grammar -* :ref:`oc.env`: Accessing environment variables. +* :ref:`oc.deprecated`: Deprecate a key in your config +* :ref:`oc.env`: Accessing environment variables * :ref:`oc.dict.{keys,values}`: Viewing the keys or the values of a dictionary as a list diff --git a/news/681.feature b/news/681.feature new file mode 100644 index 000000000..e66dcac5f --- /dev/null +++ b/news/681.feature @@ -0,0 +1 @@ +Introduce oc.deprecated resolver, that enables deprecating config nodes diff --git a/omegaconf/_impl.py b/omegaconf/_impl.py index 3f9deba49..8ae40861c 100644 --- a/omegaconf/_impl.py +++ b/omegaconf/_impl.py @@ -63,7 +63,9 @@ def select_value( throw_on_missing=throw_on_missing, absolute_key=absolute_key, ) + if isinstance(ret, Node) and ret._is_missing(): + assert not throw_on_missing # would have raised an exception in select_node return None return _get_value(ret) @@ -106,6 +108,4 @@ def select_node( ): return default - if value is not None and value._is_missing(): - return None return value diff --git a/omegaconf/base.py b/omegaconf/base.py index 346952e4e..2314681be 100644 --- a/omegaconf/base.py +++ b/omegaconf/base.py @@ -633,7 +633,17 @@ def _evaluate_custom_resolver( resolver = OmegaConf._get_resolver(inter_type) if resolver is not None: root_node = self._get_root() - return resolver(root_node, self, inter_args, inter_args_str) + node = None + if key is not None: + node = self._get_node(key, validate_access=True) + assert node is None or isinstance(node, Node) + return resolver( + root_node, + self, + node, + inter_args, + inter_args_str, + ) else: raise UnsupportedInterpolationType( f"Unsupported interpolation type {inter_type}" diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index e804fe022..e20d9c300 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -94,6 +94,7 @@ def register_default_resolvers() -> None: from omegaconf.resolvers import env, oc OmegaConf.register_new_resolver("oc.decode", oc.decode) + OmegaConf.register_new_resolver("oc.deprecated", oc.deprecated) OmegaConf.register_new_resolver("oc.env", oc.env) OmegaConf.register_new_resolver("oc.dict.keys", oc.dict.keys) OmegaConf.register_new_resolver("oc.dict.values", oc.dict.values) @@ -323,7 +324,8 @@ def legacy_register_resolver(name: str, resolver: Resolver) -> None: def resolver_wrapper( config: BaseContainer, - node: BaseContainer, + parent: BaseContainer, + node: Node, args: Tuple[Any, ...], args_str: Tuple[str, ...], ) -> Any: @@ -401,11 +403,13 @@ def _should_pass(special: str) -> bool: return ret pass_parent = _should_pass("_parent_") + pass_node = _should_pass("_node_") pass_root = _should_pass("_root_") def resolver_wrapper( config: BaseContainer, parent: Container, + node: Node, args: Tuple[Any, ...], args_str: Tuple[str, ...], ) -> Any: @@ -417,9 +421,11 @@ def resolver_wrapper( pass # Call resolver. - kwargs = {} + kwargs: Dict[str, Node] = {} if pass_parent: kwargs["_parent_"] = parent + if pass_node: + kwargs["_node_"] = node if pass_root: kwargs["_root_"] = config @@ -442,7 +448,7 @@ def get_resolver( cls, name: str, ) -> Optional[ - Callable[[Container, Container, Tuple[Any, ...], Tuple[str, ...]], Any] + Callable[[Container, Container, Node, Tuple[Any, ...], Tuple[str, ...]], Any] ]: warnings.warn( "`OmegaConf.get_resolver()` is deprecated (see https://github.com/omry/omegaconf/issues/608)", @@ -877,7 +883,10 @@ def _get_obj_type(c: Any) -> Optional[Type[Any]]: def _get_resolver( name: str, ) -> Optional[ - Callable[[Container, Container, Tuple[Any, ...], Tuple[str, ...]], Any] + Callable[ + [Container, Container, Optional[Node], Tuple[Any, ...], Tuple[str, ...]], + Any, + ] ]: # noinspection PyProtectedMember return ( diff --git a/omegaconf/resolvers/oc/__init__.py b/omegaconf/resolvers/oc/__init__.py index 72b2be2b9..0d4baae4a 100644 --- a/omegaconf/resolvers/oc/__init__.py +++ b/omegaconf/resolvers/oc/__init__.py @@ -1,8 +1,11 @@ import os +import string +import warnings from typing import Any, Optional -from omegaconf import Container +from omegaconf import Container, Node from omegaconf._utils import _DEFAULT_MARKER_, _get_value +from omegaconf.errors import ConfigKeyError from omegaconf.grammar_parser import parse from omegaconf.resolvers.oc import dict @@ -46,8 +49,44 @@ def decode(expr: Optional[str], _parent_: Container) -> Any: return _get_value(val) +def deprecated( + key: str, + message: str = "'$OLD_KEY' is deprecated. Change your code and config to use '$NEW_KEY'", + *, + _parent_: Container, + _node_: Optional[Node], +) -> Any: + from omegaconf._impl import select_node + + if not isinstance(key, str): + raise TypeError( + f"oc.deprecated: interpolation key type is not a string ({type(key).__name__})" + ) + + if not isinstance(message, str): + raise TypeError( + f"oc.deprecated: interpolation message type is not a string ({type(message).__name__})" + ) + + assert _node_ is not None + full_key = _node_._get_full_key(key=None) + target_node = select_node(_parent_, key, absolute_key=True) + if target_node is None: + raise ConfigKeyError( + f"In oc.deprecated resolver at '{full_key}': Key not found: '{key}'" + ) + new_key = target_node._get_full_key(key=None) + msg = string.Template(message).safe_substitute( + OLD_KEY=full_key, + NEW_KEY=new_key, + ) + warnings.warn(category=UserWarning, message=msg) + return target_node + + __all__ = [ "decode", + "deprecated", "dict", "env", ] diff --git a/tests/interpolation/built_in_resolvers/test_oc_deprecated.py b/tests/interpolation/built_in_resolvers/test_oc_deprecated.py new file mode 100644 index 000000000..7f03ab34a --- /dev/null +++ b/tests/interpolation/built_in_resolvers/test_oc_deprecated.py @@ -0,0 +1,107 @@ +import re +from typing import Any + +from pytest import mark, param, raises, warns + +from omegaconf import OmegaConf +from omegaconf.errors import InterpolationResolutionError + + +@mark.parametrize( + ("cfg", "key", "expected_value", "expected_warning"), + [ + param( + {"a": 10, "b": "${oc.deprecated: a}"}, + "b", + 10, + "'b' is deprecated. Change your code and config to use 'a'", + id="value", + ), + param( + {"a": 10, "b": "${oc.deprecated: a, '$OLD_KEY is deprecated'}"}, + "b", + 10, + "b is deprecated", + id="value-custom-message", + ), + param( + { + "a": 10, + "b": "${oc.deprecated: a, ${warning}}", + "warning": "$OLD_KEY is bad, $NEW_KEY is good", + }, + "b", + 10, + "b is bad, a is good", + id="value-custom-message-config-variable", + ), + param( + {"a": {"b": 10}, "b": "${oc.deprecated: a}"}, + "b", + OmegaConf.create({"b": 10}), + "'b' is deprecated. Change your code and config to use 'a'", + id="dict", + ), + param( + {"a": {"b": 10}, "b": "${oc.deprecated: a}"}, + "b.b", + 10, + "'b' is deprecated. Change your code and config to use 'a'", + id="dict_value", + ), + param( + {"a": [0, 1], "b": "${oc.deprecated: a}"}, + "b", + OmegaConf.create([0, 1]), + "'b' is deprecated. Change your code and config to use 'a'", + id="list", + ), + param( + {"a": [0, 1], "b": "${oc.deprecated: a}"}, + "b[1]", + 1, + "'b' is deprecated. Change your code and config to use 'a'", + id="list_value", + ), + ], +) +def test_deprecated( + cfg: Any, key: str, expected_value: Any, expected_warning: str +) -> None: + cfg = OmegaConf.create(cfg) + with warns(UserWarning, match=re.escape(expected_warning)): + value = OmegaConf.select(cfg, key) + assert value == expected_value + assert type(value) == type(expected_value) + + +@mark.parametrize( + ("cfg", "error"), + [ + param( + {"a": "${oc.deprecated: z}"}, + "ConfigKeyError raised while resolving interpolation:" + " In oc.deprecated resolver at 'a': Key not found: 'z'", + id="target_not_found", + ), + param( + {"a": "${oc.deprecated: 111111}"}, + "TypeError raised while resolving interpolation: oc.deprecated:" + " interpolation key type is not a string (int)", + id="invalid_key_type", + ), + param( + {"a": "${oc.deprecated: b, 1000}", "b": 10}, + "TypeError raised while resolving interpolation: oc.deprecated:" + " interpolation message type is not a string (int)", + id="invalid_message_type", + ), + ], +) +def test_deprecated_target_not_found(cfg: Any, error: str) -> None: + cfg = OmegaConf.create(cfg) + with raises( + InterpolationResolutionError, + match=re.escape(error), + ): + cfg.a