Skip to content

Commit

Permalink
oc.deprecated support
Browse files Browse the repository at this point in the history
  • Loading branch information
omry committed Apr 16, 2021
1 parent 0124b51 commit dad564e
Show file tree
Hide file tree
Showing 8 changed files with 232 additions and 16 deletions.
58 changes: 54 additions & 4 deletions docs/source/custom_resolvers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

from omegaconf import OmegaConf, DictConfig
import os
import pytest
os.environ['USER'] = 'omry'

def show(x):
print(f"type: {type(x).__name__}, value: {repr(x)}")

Expand Down Expand Up @@ -131,14 +134,14 @@ the inputs themselves:

Custom interpolations can also receive the following special parameters:

- ``_parent_`` : the parent node of an interpolation.
- ``_parent_``: the parent node of an interpolation.
- ``_root_``: The config root.

This can be achieved by adding the special parameters to the resolver signature.
Note that special parameters must be defined as named keywords (after the `*`):
Note that special parameters must be defined as named keywords (after the `*`).

In this example, we use ``_parent_`` to implement a sum function that defaults to 0 if the node does not exist.
(In contrast to the sum we defined earlier where accessing an invalid key, e.g. ``"a_plus_z": ${sum:${a}, ${z}}`` will result in an error).
In the example below, we use ``_parent_`` to implement a sum function that defaults to 0 if the node does not exist.
(In contrast to the sum we defined earlier where accessing an invalid key, e.g. ``"a_plus_z": ${sum:${a}, ${z}}`` would result in an error).

.. doctest::

Expand Down Expand Up @@ -189,6 +192,53 @@ In such a case, the default value is converted to a string using ``str(default)`

The following example falls back to default passwords when ``DB_PASSWORD`` is not defined:

.. doctest::

>>> cfg = OmegaConf.create(
... {
... "database": {
... "password1": "${oc.env:DB_PASSWORD,password}",
... "password2": "${oc.env:DB_PASSWORD,12345}",
... "password3": "${oc.env:DB_PASSWORD,null}",
... },
... }
... )
>>> # default is already a string
>>> show(cfg.database.password1)
type: str, value: 'password'
>>> # default is converted to a string automatically
>>> show(cfg.database.password2)
type: str, value: '12345'
>>> # unless it's None
>>> show(cfg.database.password3)
type: NoneType, value: None

.. _oc.deprecated:

oc.deprecated
^^^^^^^^^^^^^
``oc.deprecated`` enables you to deprecate a config node.
It takes two parameters:

- ``key``: An interpolation key representing the new key you are migrating to. This parameter is required.
- ``message``: A message to use as the warning when the config node is being accessed. The default message is
``'$OLD_KEY' is deprecated. Change your code and config to use '$NEW_KEY'``.

.. doctest::

>>> conf = OmegaConf.create({
... "rusty_key": "${oc.deprecated:shiny_key}",
... "custom_msg": "${oc.deprecated:shiny_key, 'Use $NEW_KEY'}",
... "shiny_key": 10
... })
>>> # Accessing rusty_key will issue a deprecation warning
>>> # and return the new value automatically
>>> warning = "'rusty_key' is deprecated. Change your" \
... " code and config to use 'shiny_key'"
>>> with pytest.warns(UserWarning, match=warning):
... assert conf.rusty_key == 10
>>> with pytest.warns(UserWarning, match="Use shiny_key"):
... assert conf.custom_msg == 10

.. _oc.decode:

Expand Down
8 changes: 4 additions & 4 deletions docs/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import sys
import tempfile
import pickle
os.environ['USER'] = 'omry'
# ensures that DB_TIMEOUT is not set in the doc.
os.environ.pop('DB_TIMEOUT', None)

Expand Down Expand Up @@ -385,9 +384,9 @@ Interpolated nodes can be any node in the config, not just leaf nodes:

Resolvers
^^^^^^^^^
You can add additional interpolation types by registering resolvers using ``OmegaConf.register_new_resolver()``.
Add new interpolation types by registering resolvers using ``OmegaConf.register_new_resolver()``.
Such resolvers are called when the config node is accessed.
See :doc:`custom_resolvers` for more details, or keep reading for a minimal example.
The minimal example below shows its most basic usage, see :doc:`custom_resolvers` for more details.

.. doctest::

Expand All @@ -403,7 +402,8 @@ Built-in resolvers
OmegaConf comes with a set of built-in custom resolvers:

* :ref:`oc.decode`: Parsing an input string using interpolation grammar
* :ref:`oc.env`: Accessing environment variables.
* :ref:`oc.deprecated`: Deprecate a key in your config
* :ref:`oc.env`: Accessing environment variables
* :ref:`oc.dict.{keys,values}`: Viewing the keys or the values of a dictionary as a list


Expand Down
1 change: 1 addition & 0 deletions news/681.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Introduce oc.deprecated resolver, that enables deprecating config nodes
4 changes: 2 additions & 2 deletions omegaconf/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ def select_value(
throw_on_missing=throw_on_missing,
absolute_key=absolute_key,
)

if isinstance(ret, Node) and ret._is_missing():
assert not throw_on_missing # would have raised an exception in select_node
return None

return _get_value(ret)
Expand Down Expand Up @@ -106,6 +108,4 @@ def select_node(
):
return default

if value is not None and value._is_missing():
return None
return value
12 changes: 11 additions & 1 deletion omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,17 @@ def _evaluate_custom_resolver(
resolver = OmegaConf._get_resolver(inter_type)
if resolver is not None:
root_node = self._get_root()
return resolver(root_node, self, inter_args, inter_args_str)
node = None
if key is not None:
node = self._get_node(key, validate_access=True)
assert node is None or isinstance(node, Node)
return resolver(
root_node,
self,
node,
inter_args,
inter_args_str,
)
else:
raise UnsupportedInterpolationType(
f"Unsupported interpolation type {inter_type}"
Expand Down
17 changes: 13 additions & 4 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def register_default_resolvers() -> None:
from omegaconf.resolvers import env, oc

OmegaConf.register_new_resolver("oc.decode", oc.decode)
OmegaConf.register_new_resolver("oc.deprecated", oc.deprecated)
OmegaConf.register_new_resolver("oc.env", oc.env)
OmegaConf.register_new_resolver("oc.dict.keys", oc.dict.keys)
OmegaConf.register_new_resolver("oc.dict.values", oc.dict.values)
Expand Down Expand Up @@ -324,7 +325,8 @@ def legacy_register_resolver(name: str, resolver: Resolver) -> None:

def resolver_wrapper(
config: BaseContainer,
node: BaseContainer,
parent: BaseContainer,
node: Node,
args: Tuple[Any, ...],
args_str: Tuple[str, ...],
) -> Any:
Expand Down Expand Up @@ -402,11 +404,13 @@ def _should_pass(special: str) -> bool:
return ret

pass_parent = _should_pass("_parent_")
pass_node = _should_pass("_node_")
pass_root = _should_pass("_root_")

def resolver_wrapper(
config: BaseContainer,
parent: Container,
node: Node,
args: Tuple[Any, ...],
args_str: Tuple[str, ...],
) -> Any:
Expand All @@ -418,9 +422,11 @@ def resolver_wrapper(
pass

# Call resolver.
kwargs = {}
kwargs: Dict[str, Node] = {}
if pass_parent:
kwargs["_parent_"] = parent
if pass_node:
kwargs["_node_"] = node
if pass_root:
kwargs["_root_"] = config

Expand All @@ -443,7 +449,7 @@ def get_resolver(
cls,
name: str,
) -> Optional[
Callable[[Container, Container, Tuple[Any, ...], Tuple[str, ...]], Any]
Callable[[Container, Container, Node, Tuple[Any, ...], Tuple[str, ...]], Any]
]:
warnings.warn(
"`OmegaConf.get_resolver()` is deprecated (see https://github.com/omry/omegaconf/issues/608)",
Expand Down Expand Up @@ -878,7 +884,10 @@ def _get_obj_type(c: Any) -> Optional[Type[Any]]:
def _get_resolver(
name: str,
) -> Optional[
Callable[[Container, Container, Tuple[Any, ...], Tuple[str, ...]], Any]
Callable[
[Container, Container, Optional[Node], Tuple[Any, ...], Tuple[str, ...]],
Any,
]
]:
# noinspection PyProtectedMember
return (
Expand Down
41 changes: 40 additions & 1 deletion omegaconf/resolvers/oc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import os
import string
import warnings
from typing import Any, Optional

from omegaconf import Container
from omegaconf import Container, Node
from omegaconf._utils import _DEFAULT_MARKER_, _get_value
from omegaconf.errors import ConfigKeyError
from omegaconf.grammar_parser import parse
from omegaconf.resolvers.oc import dict

Expand Down Expand Up @@ -46,8 +49,44 @@ def decode(expr: Optional[str], _parent_: Container) -> Any:
return _get_value(val)


def deprecated(
key: str,
message: str = "'$OLD_KEY' is deprecated. Change your code and config to use '$NEW_KEY'",
*,
_parent_: Container,
_node_: Optional[Node],
) -> Any:
from omegaconf._impl import select_node

if not isinstance(key, str):
raise TypeError(
f"oc.deprecated: interpolation key type is not a string ({type(key).__name__})"
)

if not isinstance(message, str):
raise TypeError(
f"oc.deprecated: interpolation message type is not a string ({type(message).__name__})"
)

assert _node_ is not None
full_key = _node_._get_full_key(key=None)
target_node = select_node(_parent_, key, absolute_key=True)
if target_node is None:
raise ConfigKeyError(
f"In oc.deprecated resolver at '{full_key}': Key not found: '{key}'"
)
new_key = target_node._get_full_key(key=None)
msg = string.Template(message).safe_substitute(
OLD_KEY=full_key,
NEW_KEY=new_key,
)
warnings.warn(category=UserWarning, message=msg)
return target_node


__all__ = [
"decode",
"deprecated",
"dict",
"env",
]
107 changes: 107 additions & 0 deletions tests/interpolation/built_in_resolvers/test_oc_deprecated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import re
from typing import Any

from pytest import mark, param, raises, warns

from omegaconf import OmegaConf
from omegaconf.errors import InterpolationResolutionError


@mark.parametrize(
("cfg", "key", "expected_value", "expected_warning"),
[
param(
{"a": 10, "b": "${oc.deprecated: a}"},
"b",
10,
"'b' is deprecated. Change your code and config to use 'a'",
id="value",
),
param(
{"a": 10, "b": "${oc.deprecated: a, '$OLD_KEY is deprecated'}"},
"b",
10,
"b is deprecated",
id="value-custom-message",
),
param(
{
"a": 10,
"b": "${oc.deprecated: a, ${warning}}",
"warning": "$OLD_KEY is bad, $NEW_KEY is good",
},
"b",
10,
"b is bad, a is good",
id="value-custom-message-config-variable",
),
param(
{"a": {"b": 10}, "b": "${oc.deprecated: a}"},
"b",
OmegaConf.create({"b": 10}),
"'b' is deprecated. Change your code and config to use 'a'",
id="dict",
),
param(
{"a": {"b": 10}, "b": "${oc.deprecated: a}"},
"b.b",
10,
"'b' is deprecated. Change your code and config to use 'a'",
id="dict_value",
),
param(
{"a": [0, 1], "b": "${oc.deprecated: a}"},
"b",
OmegaConf.create([0, 1]),
"'b' is deprecated. Change your code and config to use 'a'",
id="list",
),
param(
{"a": [0, 1], "b": "${oc.deprecated: a}"},
"b[1]",
1,
"'b' is deprecated. Change your code and config to use 'a'",
id="list_value",
),
],
)
def test_deprecated(
cfg: Any, key: str, expected_value: Any, expected_warning: str
) -> None:
cfg = OmegaConf.create(cfg)
with warns(UserWarning, match=re.escape(expected_warning)):
value = OmegaConf.select(cfg, key)
assert value == expected_value
assert type(value) == type(expected_value)


@mark.parametrize(
("cfg", "error"),
[
param(
{"a": "${oc.deprecated: z}"},
"ConfigKeyError raised while resolving interpolation:"
" In oc.deprecated resolver at 'a': Key not found: 'z'",
id="target_not_found",
),
param(
{"a": "${oc.deprecated: 111111}"},
"TypeError raised while resolving interpolation: oc.deprecated:"
" interpolation key type is not a string (int)",
id="invalid_key_type",
),
param(
{"a": "${oc.deprecated: b, 1000}", "b": 10},
"TypeError raised while resolving interpolation: oc.deprecated:"
" interpolation message type is not a string (int)",
id="invalid_message_type",
),
],
)
def test_deprecated_target_not_found(cfg: Any, error: str) -> None:
cfg = OmegaConf.create(cfg)
with raises(
InterpolationResolutionError,
match=re.escape(error),
):
cfg.a

0 comments on commit dad564e

Please sign in to comment.