Skip to content

Commit

Permalink
oc.deprecated support
Browse files Browse the repository at this point in the history
  • Loading branch information
omry committed Apr 15, 2021
1 parent 0f6ad91 commit 65598ab
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 9 deletions.
30 changes: 30 additions & 0 deletions docs/source/custom_resolvers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

from omegaconf import OmegaConf, DictConfig
import os
import pytest
os.environ['USER'] = 'omry'

def show(x):
print(f"type: {type(x).__name__}, value: {repr(x)}")

Expand Down Expand Up @@ -190,6 +193,33 @@ In such a case, the default value is converted to a string using ``str(default)`
The following example falls back to default passwords when ``DB_PASSWORD`` is not defined:


.. _oc.deprecated:

oc.deprecated
^^^^^^^^^^^^^
``oc.deprecated`` enables you to deprecate a config node.
It takes two parameter:
- key: An interpolation key representing the new key you are migrating to. This parameter is required.
- message: A message to use as the warning when the config node is being accessed. The default message is
``'$KEY' is deprecated. Change your code and config to use '$NEW_KEY'``. Note that $KEY and $NEW_KEY
do not use interpolation syntax.

.. doctest::

>>> conf = OmegaConf.create({
... "rusty_key" : "${oc.deprecated:shiny_key}",
... "rusty_key_custom_msg" : "${oc.deprecated:shiny_key, 'Why you no use $NEW_KEY?'}",
... "shiny_key": 10
... })
>>> # Accessing rusty_key will issue a deprecation warning
>>> # and return the new value automatically
>>> warning = "'rusty_key' is deprecated. Change your" \
... " code and config to use 'shiny_key'"
>>> with pytest.warns(UserWarning, match=warning):
... assert conf.rusty_key == 10
>>> with pytest.warns(UserWarning, match="Why you no use shiny_key?"):
... assert conf.rusty_key_custom_msg == 10

.. _oc.decode:

oc.decode
Expand Down
2 changes: 1 addition & 1 deletion docs/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import sys
import tempfile
import pickle
os.environ['USER'] = 'omry'
# ensures that DB_TIMEOUT is not set in the doc.
os.environ.pop('DB_TIMEOUT', None)

Expand Down Expand Up @@ -403,6 +402,7 @@ Built-in resolvers
OmegaConf comes with a set of built-in custom resolvers:

* :ref:`oc.decode`: Parsing an input string using interpolation grammar
* :ref:`oc.deprecated`: Deprecate a key in your config
* :ref:`oc.env`: Accessing environment variables.
* :ref:`oc.dict.{keys,values}`: Viewing the keys or the values of a dictionary as a list

Expand Down
2 changes: 0 additions & 2 deletions omegaconf/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,4 @@ def select_node(
):
return default

if value is not None and value._is_missing():
return None
return value
12 changes: 11 additions & 1 deletion omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,17 @@ def _evaluate_custom_resolver(
resolver = OmegaConf._get_resolver(inter_type)
if resolver is not None:
root_node = self._get_root()
return resolver(root_node, self, inter_args, inter_args_str)
node = None
if key is not None:
node = self._get_node(key, validate_access=True)
assert node is None or isinstance(node, Node)
return resolver(
root_node,
self,
node,
inter_args,
inter_args_str,
)
else:
raise UnsupportedInterpolationType(
f"Unsupported interpolation type {inter_type}"
Expand Down
17 changes: 13 additions & 4 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def register_default_resolvers() -> None:
from omegaconf.resolvers import env, oc

OmegaConf.register_new_resolver("oc.decode", oc.decode)
OmegaConf.register_new_resolver("oc.deprecated", oc.deprecated)
OmegaConf.register_new_resolver("oc.env", oc.env)
OmegaConf.register_new_resolver("oc.dict.keys", oc.dict.keys)
OmegaConf.register_new_resolver("oc.dict.values", oc.dict.values)
Expand Down Expand Up @@ -323,7 +324,8 @@ def legacy_register_resolver(name: str, resolver: Resolver) -> None:

def resolver_wrapper(
config: BaseContainer,
node: BaseContainer,
parent: BaseContainer,
node: Node,
args: Tuple[Any, ...],
args_str: Tuple[str, ...],
) -> Any:
Expand Down Expand Up @@ -401,11 +403,13 @@ def _should_pass(special: str) -> bool:
return ret

pass_parent = _should_pass("_parent_")
pass_node = _should_pass("_node_")
pass_root = _should_pass("_root_")

def resolver_wrapper(
config: BaseContainer,
parent: Container,
node: Node,
args: Tuple[Any, ...],
args_str: Tuple[str, ...],
) -> Any:
Expand All @@ -417,9 +421,11 @@ def resolver_wrapper(
pass

# Call resolver.
kwargs = {}
kwargs: Dict[str, Node] = {}
if pass_parent:
kwargs["_parent_"] = parent
if pass_node:
kwargs["_node_"] = node
if pass_root:
kwargs["_root_"] = config

Expand All @@ -442,7 +448,7 @@ def get_resolver(
cls,
name: str,
) -> Optional[
Callable[[Container, Container, Tuple[Any, ...], Tuple[str, ...]], Any]
Callable[[Container, Container, Node, Tuple[Any, ...], Tuple[str, ...]], Any]
]:
warnings.warn(
"`OmegaConf.get_resolver()` is deprecated (see https://github.com/omry/omegaconf/issues/608)",
Expand Down Expand Up @@ -877,7 +883,10 @@ def _get_obj_type(c: Any) -> Optional[Type[Any]]:
def _get_resolver(
name: str,
) -> Optional[
Callable[[Container, Container, Tuple[Any, ...], Tuple[str, ...]], Any]
Callable[
[Container, Container, Optional[Node], Tuple[Any, ...], Tuple[str, ...]],
Any,
]
]:
# noinspection PyProtectedMember
return (
Expand Down
39 changes: 38 additions & 1 deletion omegaconf/resolvers/oc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import os
import string
import warnings
from typing import Any, Optional

from omegaconf import Container
from omegaconf import Container, Node
from omegaconf._utils import _DEFAULT_MARKER_, _get_value
from omegaconf.errors import ConfigKeyError
from omegaconf.grammar_parser import parse
from omegaconf.resolvers.oc import dict

Expand Down Expand Up @@ -46,8 +49,42 @@ def decode(expr: Optional[str], _parent_: Container) -> Any:
return _get_value(val)


def deprecated(
key: str,
message: str = "'$KEY' is deprecated. Change your code and config to use '$NEW_KEY'",
*,
_parent_: Container,
_node_: Optional[Node],
) -> Any:
from omegaconf._impl import select_node

if not isinstance(key, str):
raise ValueError(
f"oc.deprecated: interpolation key type is not a string ({type(key).__name__})"
)
if not isinstance(message, str):
raise ValueError(
f"oc.deprecated: interpolation message type is not a string ({type(message).__name__})"
)
assert _node_ is not None
full_key = _node_._get_full_key(key=None)
target_node = select_node(_parent_, key, absolute_key=True)
if target_node is None:
raise ConfigKeyError(
f"In oc.deprecate resolver at '{full_key}': Key not found: '{key}'"
)
new_key = target_node._get_full_key(key=None)
msg = string.Template(message).safe_substitute(
KEY=full_key,
NEW_KEY=new_key,
)
warnings.warn(category=UserWarning, message=msg)
return target_node


__all__ = [
"decode",
"deprecated",
"dict",
"env",
]
107 changes: 107 additions & 0 deletions tests/interpolation/built_in_resolvers/test_oc_deprecated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import re
from typing import Any

from pytest import mark, param, raises, warns

from omegaconf import OmegaConf
from omegaconf._utils import _ensure_container
from omegaconf.errors import InterpolationResolutionError


@mark.parametrize(
("cfg", "key", "expected_value", "expected_warning"),
[
param(
{"a": 10, "b": "${oc.deprecated: a}"},
"b",
10,
"'b' is deprecated. Change your code and config to use 'a'",
id="value",
),
param(
{"a": 10, "b": "${oc.deprecated: a, '$KEY is deprecated'}"},
"b",
10,
"b is deprecated",
id="value-custom-message",
),
param(
{
"a": 10,
"b": "${oc.deprecated: a, ${warning}}",
"warning": "$KEY is bad, $NEW_KEY is good",
},
"b",
10,
"b is bad, a is good",
id="value-custom-message-config-variable",
),
param(
{"a": {"b": 10}, "b": "${oc.deprecated: a}"},
"b",
OmegaConf.create({"b": 10}),
"'b' is deprecated. Change your code and config to use 'a'",
id="dict",
),
param(
{"a": {"b": 10}, "b": "${oc.deprecated: a}"},
"b.b",
10,
"'b' is deprecated. Change your code and config to use 'a'",
id="dict_value",
),
param(
{"a": [0, 1], "b": "${oc.deprecated: a}"},
"b",
OmegaConf.create([0, 1]),
"'b' is deprecated. Change your code and config to use 'a'",
id="list",
),
param(
{"a": [0, 1], "b": "${oc.deprecated: a}"},
"b[1]",
1,
"'b' is deprecated. Change your code and config to use 'a'",
id="list_value",
),
],
)
def test_deprecated(
cfg: Any, key: str, expected_value: Any, expected_warning: str
) -> None:
cfg = _ensure_container(cfg)
with warns(UserWarning, match=re.escape(expected_warning)):
value = OmegaConf.select(cfg, key)
assert value == expected_value
assert type(value) == type(expected_value)


@mark.parametrize(
("cfg", "error"),
[
param(
{"a": "${oc.deprecated: z}"},
"ConfigKeyError raised while resolving interpolation: In oc.deprecate resolver at 'a': Key not found: 'z'",
id="target_not_found",
),
param(
{"a": "${oc.deprecated: 111111}"},
"ValueError raised while resolving interpolation: oc.deprecated:"
" interpolation key type is not a string (int)",
id="invalid_key_type",
),
param(
{"a": "${oc.deprecated: b, 1000}", "b": 10},
"ValueError raised while resolving interpolation: oc.deprecated:"
" interpolation message type is not a string (int)",
id="invalid_message_type",
),
],
)
def test_deprecated_target_not_found(cfg: Any, error: str) -> None:
cfg = _ensure_container(cfg)
with raises(
InterpolationResolutionError,
match=re.escape(error),
):
cfg.a

0 comments on commit 65598ab

Please sign in to comment.