From daf211a32876ac1520329f1a7431d93e35c92a32 Mon Sep 17 00:00:00 2001 From: Omry Yadan Date: Wed, 14 Apr 2021 18:38:35 -0700 Subject: [PATCH 1/3] moved custom resolvers docs into a dedicated page --- docs/source/conf.py | 2 +- docs/source/custom_resolvers.rst | 262 +++++++++++++++++++++++++++++ docs/source/index.rst | 2 + docs/source/usage.rst | 274 ++----------------------------- 4 files changed, 280 insertions(+), 260 deletions(-) create mode 100644 docs/source/custom_resolvers.rst diff --git a/docs/source/conf.py b/docs/source/conf.py index c570079bb..ba502ea08 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -158,7 +158,7 @@ "OmegaConf Documentation", author, "OmegaConf", - "Flexible python configuration system.", + "Flexible Python configuration system. The last one you will ever need.", "Miscellaneous", ) ] diff --git a/docs/source/custom_resolvers.rst b/docs/source/custom_resolvers.rst new file mode 100644 index 000000000..b1ad14c12 --- /dev/null +++ b/docs/source/custom_resolvers.rst @@ -0,0 +1,262 @@ +.. testsetup:: * + + from omegaconf import OmegaConf, DictConfig + import os + def show(x): + print(f"type: {type(x).__name__}, value: {repr(x)}") + +.. _custom_resolvers: + +Custom resolvers +---------------- + +You can add additional interpolation types by registering custom resolvers with ``OmegaConf.register_new_resolver()``: + +.. code-block:: python + + def register_new_resolver( + name: str, + resolver: Resolver, + *, + replace: bool = False, + use_cache: bool = False, + ) -> None + +Attempting to register the same resolver twice will raise a ``ValueError`` unless using ``replace=True``. + +The example below creates a resolver that adds 10 to the given value. + +.. doctest:: + + >>> OmegaConf.register_new_resolver("plus_10", lambda x: x + 10) + >>> c = OmegaConf.create({'key': '${plus_10:990}'}) + >>> c.key + 1000 + +Custom resolvers support variadic argument lists in the form of a comma separated list of zero or more values. +Whitespaces are stripped from both ends of each value ("foo,bar" is the same as "foo, bar "). +You can use literal commas and spaces anywhere by escaping (:code:`\,` and :code:`\ `), or +simply use quotes to bypass character limitations in strings. + +.. doctest:: + + >>> OmegaConf.register_new_resolver("concat", lambda x, y: x+y) + >>> c = OmegaConf.create({ + ... 'key1': '${concat:Hello,World}', + ... 'key_trimmed': '${concat:Hello , World}', + ... 'escape_whitespace': '${concat:Hello,\ World}', + ... 'quoted': '${concat:"Hello,", " World"}', + ... }) + >>> c.key1 + 'HelloWorld' + >>> c.key_trimmed + 'HelloWorld' + >>> c.escape_whitespace + 'Hello World' + >>> c.quoted + 'Hello, World' + + +Custom resolvers can return lists or dictionaries, that are automatically converted into DictConfig and ListConfig: + +.. doctest:: + + >>> OmegaConf.register_new_resolver( + ... "min_max", lambda *a: {"min": min(a), "max": max(a)} + ... ) + >>> c = OmegaConf.create({'stats': '${min_max: -1, 3, 2, 5, -10}'}) + >>> assert isinstance(c.stats, DictConfig) + >>> c.stats.min, c.stats.max + (-10, 5) + + +You can take advantage of nested interpolations to perform custom operations over variables: + +.. doctest:: + + >>> OmegaConf.register_new_resolver("sum", lambda x, y: x + y) + >>> c = OmegaConf.create({"a": 1, + ... "b": 2, + ... "a_plus_b": "${sum:${a},${b}}"}) + >>> c.a_plus_b + 3 + +More advanced resolver naming features include the ability to prefix a resolver name with a +namespace, and to use interpolations in the name itself. The following example demonstrates both: + +.. doctest:: + + >>> OmegaConf.register_new_resolver("mylib.plus1", lambda x: x + 1) + >>> c = OmegaConf.create( + ... { + ... "func": "plus1", + ... "x": "${mylib.${func}:3}", + ... } + ... ) + >>> c.x + 4 + + +By default a custom resolver is called on every access, but it is possible to cache its output +by registering it with ``use_cache=True``. +This may be useful either for performance reasons or to ensure the same value is always returned. +Note that the cache is based on the string literals representing the resolver's inputs, and not +the inputs themselves: + +.. doctest:: + + >>> import random + >>> random.seed(1234) + >>> OmegaConf.register_new_resolver( + ... "cached", random.randint, use_cache=True + ... ) + >>> OmegaConf.register_new_resolver("uncached", random.randint) + >>> c = OmegaConf.create( + ... { + ... "uncached": "${uncached:0,10000}", + ... "cached_1": "${cached:0,10000}", + ... "cached_2": "${cached:0, 10000}", + ... "cached_3": "${cached:0,${uncached}}", + ... } + ... ) + >>> # not the same since the cache is disabled by default + >>> assert c.uncached != c.uncached + >>> # same value on repeated access thanks to the cache + >>> assert c.cached_1 == c.cached_1 == 122 + >>> # same input as `cached_1` => same value + >>> assert c.cached_2 == c.cached_1 == 122 + >>> # same string literal "${uncached}" => same value + >>> assert c.cached_3 == c.cached_3 == 1192 + + +Custom interpolations can also receive the following special parameters: + +- ``_parent_`` : the parent node of an interpolation. +- ``_root_``: The config root. + +This can be achieved by adding the special parameters to the resolver signature. +Note that special parameters must be defined as named keywords (after the `*`): + +In this example, we use ``_parent_`` to implement a sum function that defaults to 0 if the node does not exist. +(In contrast to the sum we defined earlier where accessing an invalid key, e.g. ``"a_plus_z": ${sum:${a}, ${z}}`` will result in an error). + +.. doctest:: + + >>> def sum2(a, b, *, _parent_): + ... return _parent_.get(a, 0) + _parent_.get(b, 0) + >>> OmegaConf.register_new_resolver("sum2", sum2, use_cache=False) + >>> cfg = OmegaConf.create( + ... { + ... "node": { + ... "a": 1, + ... "b": 2, + ... "a_plus_b": "${sum2:a,b}", + ... "a_plus_z": "${sum2:a,z}", + ... }, + ... } + ... ) + >>> cfg.node.a_plus_b + 3 + >>> cfg.node.a_plus_z + 1 + + +Built-in resolvers +------------------ + +.. _oc.env: + +oc.env +^^^^^^ + +Access to environment variables is supported using ``oc.env``: + +Input YAML file: + +.. include:: env_interpolation.yaml + :code: yaml + +.. doctest:: + + >>> conf = OmegaConf.load('source/env_interpolation.yaml') + >>> conf.user.name + 'omry' + >>> conf.user.home + '/home/omry' + +You can specify a default value to use in case the environment variable is not set. +In such a case, the default value is converted to a string using ``str(default)``, unless it is ``null`` (representing Python ``None``) - in which case ``None`` is returned. + +The following example falls back to default passwords when ``DB_PASSWORD`` is not defined: + + +.. _oc.decode: + +oc.decode +^^^^^^^^^ + +Strings may be converted using ``oc.decode``: + +- Primitive values (e.g., ``"true"``, ``"1"``, ``"1e-3"``) are automatically converted to their corresponding type (bool, int, float) +- Dictionaries and lists (e.g., ``"{a: b}"``, ``"[a, b, c]"``) are returned as transient config nodes (DictConfig and ListConfig) +- Interpolations (e.g., ``"${foo}"``) are automatically resolved +- ``None`` is the only valid non-string input to ``oc.decode`` (returning ``None`` in that case) + +This can be useful for instance to parse environment variables: + +.. doctest:: + + >>> cfg = OmegaConf.create( + ... { + ... "database": { + ... "port": '${oc.decode:${oc.env:DB_PORT}}', + ... "nodes": '${oc.decode:${oc.env:DB_NODES}}', + ... "timeout": '${oc.decode:${oc.env:DB_TIMEOUT,null}}', + ... } + ... } + ... ) + >>> os.environ["DB_PORT"] = "3308" + >>> show(cfg.database.port) # converted to int + type: int, value: 3308 + >>> os.environ["DB_NODES"] = "[host1, host2, host3]" + >>> show(cfg.database.nodes) # converted to a ListConfig + type: ListConfig, value: ['host1', 'host2', 'host3'] + >>> show(cfg.database.timeout) # keeping `None` as is + type: NoneType, value: None + >>> os.environ["DB_TIMEOUT"] = "${.port}" + >>> show(cfg.database.timeout) # resolving interpolation + type: int, value: 3308 + + +.. _oc.dict.{keys,values}: + +oc.dict.{keys,value} +^^^^^^^^^^^^^^^^^^^^ + +Some config options that are stored as a ``DictConfig`` may sometimes be easier to manipulate as lists, +when we care only about the keys or the associated values. + +The resolvers ``oc.dict.keys`` and ``oc.dict.values`` simplify such operations by offering an alternative +view of a dictionary's keys or values as a list. +They take as input a string that is the path to another config node (using the same syntax +as interpolations) and return a ``ListConfig`` with its keys / values. + +.. doctest:: + + >>> cfg = OmegaConf.create( + ... { + ... "workers": { + ... "node3": "10.0.0.2", + ... "node7": "10.0.0.9", + ... }, + ... "nodes": "${oc.dict.keys: workers}", + ... "ips": "${oc.dict.values: workers}", + ... } + ... ) + >>> # Keys are copied from the DictConfig: + >>> show(cfg.nodes) + type: ListConfig, value: ['node3', 'node7'] + >>> # Values are dynamically fetched through interpolations: + >>> show(cfg.ips) + type: ListConfig, value: ['${workers.node3}', '${workers.node7}'] + >>> assert cfg.ips == ["10.0.0.2", "10.0.0.9"] diff --git a/docs/source/index.rst b/docs/source/index.rst index 95796d63c..969121f80 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -11,9 +11,11 @@ OmegaConf also offers runtime type safety via Structured Configs. :maxdepth: 2 usage + custom_resolvers structured_config + Indices and tables ================== diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 418206071..a02c481fe 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -383,272 +383,28 @@ Interpolated nodes can be any node in the config, not just leaf nodes: >>> (cfg.player.height, cfg.player.weight) (180, 75) - -Environment variable interpolation -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Access to environment variables is supported using ``oc.env``: - -Input YAML file: - -.. include:: env_interpolation.yaml - :code: yaml - -.. doctest:: - - >>> conf = OmegaConf.load('source/env_interpolation.yaml') - >>> conf.user.name - 'omry' - >>> conf.user.home - '/home/omry' - -You can specify a default value to use in case the environment variable is not set. -In such a case, the default value is converted to a string using ``str(default)``, unless it is ``null`` (representing Python ``None``) - in which case ``None`` is returned. - -The following example falls back to default passwords when ``DB_PASSWORD`` is not defined: - -.. doctest:: - - >>> cfg = OmegaConf.create( - ... { - ... "database": { - ... "password1": "${oc.env:DB_PASSWORD,password}", - ... "password2": "${oc.env:DB_PASSWORD,12345}", - ... "password3": "${oc.env:DB_PASSWORD,null}", - ... }, - ... } - ... ) - >>> # default is already a string - >>> show(cfg.database.password1) - type: str, value: 'password' - >>> # default is converted to a string automatically - >>> show(cfg.database.password2) - type: str, value: '12345' - >>> # unless it's None - >>> show(cfg.database.password3) - type: NoneType, value: None - - -Decoding strings with interpolations -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Strings may be converted using ``oc.decode``: - -- Primitive values (e.g., ``"true"``, ``"1"``, ``"1e-3"``) are automatically converted to their corresponding type (bool, int, float) -- Dictionaries and lists (e.g., ``"{a: b}"``, ``"[a, b, c]"``) are returned as transient config nodes (DictConfig and ListConfig) -- Interpolations (e.g., ``"${foo}"``) are automatically resolved -- ``None`` is the only valid non-string input to ``oc.decode`` (returning ``None`` in that case) - -This can be useful for instance to parse environment variables: - -.. doctest:: - - >>> cfg = OmegaConf.create( - ... { - ... "database": { - ... "port": '${oc.decode:${oc.env:DB_PORT}}', - ... "nodes": '${oc.decode:${oc.env:DB_NODES}}', - ... "timeout": '${oc.decode:${oc.env:DB_TIMEOUT,null}}', - ... } - ... } - ... ) - >>> os.environ["DB_PORT"] = "3308" - >>> show(cfg.database.port) # converted to int - type: int, value: 3308 - >>> os.environ["DB_NODES"] = "[host1, host2, host3]" - >>> show(cfg.database.nodes) # converted to a ListConfig - type: ListConfig, value: ['host1', 'host2', 'host3'] - >>> show(cfg.database.timeout) # keeping `None` as is - type: NoneType, value: None - >>> os.environ["DB_TIMEOUT"] = "${.port}" - >>> show(cfg.database.timeout) # resolving interpolation - type: int, value: 3308 - - -Extracting lists of keys / values from a dictionary -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Some config options that are stored as a ``DictConfig`` may sometimes be easier to manipulate as lists, -when we care only about the keys or the associated values. - -The resolvers ``oc.dict.keys`` and ``oc.dict.values`` simplify such operations by offering an alternative -view of a dictionary's keys or values as a list. -They take as input a string that is the path to another config node (using the same syntax -as interpolations) and return a ``ListConfig`` with its keys / values. - -.. doctest:: - - >>> cfg = OmegaConf.create( - ... { - ... "workers": { - ... "node3": "10.0.0.2", - ... "node7": "10.0.0.9", - ... }, - ... "nodes": "${oc.dict.keys: workers}", - ... "ips": "${oc.dict.values: workers}", - ... } - ... ) - >>> # Keys are copied from the DictConfig: - >>> show(cfg.nodes) - type: ListConfig, value: ['node3', 'node7'] - >>> # Values are dynamically fetched through interpolations: - >>> show(cfg.ips) - type: ListConfig, value: ['${workers.node3}', '${workers.node7}'] - >>> assert cfg.ips == ["10.0.0.2", "10.0.0.9"] - - -Custom interpolations -^^^^^^^^^^^^^^^^^^^^^ - -You can add additional interpolation types by registering custom resolvers with ``OmegaConf.register_new_resolver()``: - -.. code-block:: python - - def register_new_resolver( - name: str, - resolver: Resolver, - *, - replace: bool = False, - use_cache: bool = False, - ) -> None - -Attempting to register the same resolver twice will raise a ``ValueError`` unless using ``replace=True``. - -The example below creates a resolver that adds 10 to the given value. - -.. doctest:: - - >>> OmegaConf.register_new_resolver("plus_10", lambda x: x + 10) - >>> c = OmegaConf.create({'key': '${plus_10:990}'}) - >>> c.key - 1000 - -Custom resolvers support variadic argument lists in the form of a comma separated list of zero or more values. -Whitespaces are stripped from both ends of each value ("foo,bar" is the same as "foo, bar "). -You can use literal commas and spaces anywhere by escaping (:code:`\,` and :code:`\ `), or -simply use quotes to bypass character limitations in strings. - -.. doctest:: - - >>> OmegaConf.register_new_resolver("concat", lambda x, y: x+y) - >>> c = OmegaConf.create({ - ... 'key1': '${concat:Hello,World}', - ... 'key_trimmed': '${concat:Hello , World}', - ... 'escape_whitespace': '${concat:Hello,\ World}', - ... 'quoted': '${concat:"Hello,", " World"}', - ... }) - >>> c.key1 - 'HelloWorld' - >>> c.key_trimmed - 'HelloWorld' - >>> c.escape_whitespace - 'Hello World' - >>> c.quoted - 'Hello, World' - - -Custom resolvers can return lists or dictionaries, that are automatically converted into DictConfig and ListConfig: +Resolvers +^^^^^^^^^ +You can add additional interpolation types by registering resolvers using ``OmegaConf.register_new_resolver()``. +Such resolvers are called when the config node is accessed. +See :doc:`custom_resolvers` for more details, or keep reading for a minimal example. .. doctest:: >>> OmegaConf.register_new_resolver( - ... "min_max", lambda *a: {"min": min(a), "max": max(a)} + ... "add", lambda *numbers: sum(numbers) ... ) - >>> c = OmegaConf.create({'stats': '${min_max: -1, 3, 2, 5, -10}'}) - >>> assert isinstance(c.stats, DictConfig) - >>> c.stats.min, c.stats.max - (-10, 5) + >>> c = OmegaConf.create({'total': '${add:1,2,3}'}) + >>> c.total + 6 +Built-in resolvers +^^^^^^^^^^^^^^^^^^ +OmegaConf comes with a set of built-in custom resolvers: -You can take advantage of nested interpolations to perform custom operations over variables: - -.. doctest:: - - >>> OmegaConf.register_new_resolver("sum", lambda x, y: x + y) - >>> c = OmegaConf.create({"a": 1, - ... "b": 2, - ... "a_plus_b": "${sum:${a},${b}}"}) - >>> c.a_plus_b - 3 - -More advanced resolver naming features include the ability to prefix a resolver name with a -namespace, and to use interpolations in the name itself. The following example demonstrates both: - -.. doctest:: - - >>> OmegaConf.register_new_resolver("mylib.plus1", lambda x: x + 1) - >>> c = OmegaConf.create( - ... { - ... "func": "plus1", - ... "x": "${mylib.${func}:3}", - ... } - ... ) - >>> c.x - 4 - - -By default a custom resolver is called on every access, but it is possible to cache its output -by registering it with ``use_cache=True``. -This may be useful either for performance reasons or to ensure the same value is always returned. -Note that the cache is based on the string literals representing the resolver's inputs, and not -the inputs themselves: - -.. doctest:: - - >>> import random - >>> random.seed(1234) - >>> OmegaConf.register_new_resolver( - ... "cached", random.randint, use_cache=True - ... ) - >>> OmegaConf.register_new_resolver("uncached", random.randint) - >>> c = OmegaConf.create( - ... { - ... "uncached": "${uncached:0,10000}", - ... "cached_1": "${cached:0,10000}", - ... "cached_2": "${cached:0, 10000}", - ... "cached_3": "${cached:0,${uncached}}", - ... } - ... ) - >>> # not the same since the cache is disabled by default - >>> assert c.uncached != c.uncached - >>> # same value on repeated access thanks to the cache - >>> assert c.cached_1 == c.cached_1 == 122 - >>> # same input as `cached_1` => same value - >>> assert c.cached_2 == c.cached_1 == 122 - >>> # same string literal "${uncached}" => same value - >>> assert c.cached_3 == c.cached_3 == 1192 - - -Custom interpolations can also receive the following special parameters: - -- ``_parent_`` : the parent node of an interpolation. -- ``_root_``: The config root. - -This can be achieved by adding the special parameters to the resolver signature. -Note that special parameters must be defined as named keywords (after the `*`): - -In this example, we use ``_parent_`` to implement a sum function that defaults to 0 if the node does not exist. -(In contrast to the sum we defined earlier where accessing an invalid key, e.g. ``"a_plus_z": ${sum:${a}, ${z}}`` will result in an error). - -.. doctest:: - - >>> def sum2(a, b, *, _parent_): - ... return _parent_.get(a, 0) + _parent_.get(b, 0) - >>> OmegaConf.register_new_resolver("sum2", sum2, use_cache=False) - >>> cfg = OmegaConf.create( - ... { - ... "node": { - ... "a": 1, - ... "b": 2, - ... "a_plus_b": "${sum2:a,b}", - ... "a_plus_z": "${sum2:a,z}", - ... }, - ... } - ... ) - >>> cfg.node.a_plus_b - 3 - >>> cfg.node.a_plus_z - 1 +* :ref:`oc.decode`: Parsing an input string using interpolation grammar +* :ref:`oc.env`: Accessing environment variables. +* :ref:`oc.dict.{keys,values}`: Viewing the keys or the values of a dictionary as a list Merging configurations From 3cf40ec6fee09fead61a4f853f8120a700810c23 Mon Sep 17 00:00:00 2001 From: Omry Yadan Date: Sun, 11 Apr 2021 18:03:37 -0700 Subject: [PATCH 2/3] refactor select_value and select_node out of OmegaConf.select --- omegaconf/_impl.py | 69 +++++++++++++++++++++++++++++++++- omegaconf/omegaconf.py | 46 +++++------------------ omegaconf/resolvers/oc/dict.py | 4 +- tests/test_select.py | 9 +++-- 4 files changed, 84 insertions(+), 44 deletions(-) diff --git a/omegaconf/_impl.py b/omegaconf/_impl.py index 61207e0cc..3f9deba49 100644 --- a/omegaconf/_impl.py +++ b/omegaconf/_impl.py @@ -1,7 +1,9 @@ from typing import Any from omegaconf import MISSING, Container, DictConfig, ListConfig, Node, ValueNode -from omegaconf.errors import InterpolationToMissingValueError +from omegaconf.errors import ConfigKeyError, InterpolationToMissingValueError + +from ._utils import _DEFAULT_MARKER_, _get_value def _resolve_container_value(cfg: Container, key: Any) -> None: @@ -42,3 +44,68 @@ def _resolve(cfg: Node) -> Node: _resolve_container_value(cfg, i) return cfg + + +def select_value( + cfg: Container, + key: str, + *, + default: Any = _DEFAULT_MARKER_, + throw_on_resolution_failure: bool = True, + throw_on_missing: bool = False, + absolute_key: bool = False, +) -> Any: + ret = select_node( + cfg=cfg, + key=key, + default=default, + throw_on_resolution_failure=throw_on_resolution_failure, + throw_on_missing=throw_on_missing, + absolute_key=absolute_key, + ) + if isinstance(ret, Node) and ret._is_missing(): + return None + + return _get_value(ret) + + +def select_node( + cfg: Container, + key: str, + *, + default: Any = _DEFAULT_MARKER_, + throw_on_resolution_failure: bool = True, + throw_on_missing: bool = False, + absolute_key: bool = False, +) -> Any: + try: + # for non relative keys, the interpretation can be: + # 1. relative to cfg + # 2. relative to the config root + # This is controlled by the absolute_key flag. By default, such keys are relative to cfg. + if not absolute_key and not key.startswith("."): + key = f".{key}" + + cfg, key = cfg._resolve_key_and_root(key) + _root, _last_key, value = cfg._select_impl( + key, + throw_on_missing=throw_on_missing, + throw_on_resolution_failure=throw_on_resolution_failure, + ) + except ConfigKeyError: + if default is not _DEFAULT_MARKER_: + return default + else: + raise + + if ( + default is not _DEFAULT_MARKER_ + and _root is not None + and _last_key is not None + and _last_key not in _root + ): + return default + + if value is not None and value._is_missing(): + return None + return value diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index 621cdbf97..b08ecb9fd 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -30,7 +30,6 @@ from ._utils import ( _DEFAULT_MARKER_, _ensure_container, - _get_value, _is_none, format_and_raise, get_dict_key_value_types, @@ -54,7 +53,6 @@ from .base import Container, Node, SCMode from .basecontainer import BaseContainer from .errors import ( - ConfigKeyError, MissingMandatoryValue, OmegaConfBaseException, UnsupportedInterpolationType, @@ -654,7 +652,6 @@ def select( default: Any = _DEFAULT_MARKER_, throw_on_resolution_failure: bool = True, throw_on_missing: bool = False, - absolute_key: bool = False, ) -> Any: """ :param cfg: Config node to select from @@ -664,43 +661,18 @@ def select( resolution error occurs, otherwise return None :param throw_on_missing: Raise an exception if an attempt to select a missing key (with the value '???') is made, otherwise return None - :param absolute_key: True to treat non-relative keys as relative to the config root - False (default) to treat non-relative keys as relative to cfg :return: selected value or None if not found. """ - try: - try: - # for non relative keys, the interpretation can be: - # 1. relative to cfg - # 2. relative to the config root - # This is controlled by the absolute_key flag. By default, such keys are relative to cfg. - if not absolute_key and not key.startswith("."): - key = f".{key}" - - cfg, key = cfg._resolve_key_and_root(key) - _root, _last_key, value = cfg._select_impl( - key, - throw_on_missing=throw_on_missing, - throw_on_resolution_failure=throw_on_resolution_failure, - ) - except ConfigKeyError: - if default is not _DEFAULT_MARKER_: - return default - else: - raise - - if ( - default is not _DEFAULT_MARKER_ - and _root is not None - and _last_key is not None - and _last_key not in _root - ): - return default + from ._impl import select_value - if value is not None and value._is_missing(): - return None - - return _get_value(value) + try: + return select_value( + cfg=cfg, + key=key, + default=default, + throw_on_resolution_failure=throw_on_resolution_failure, + throw_on_missing=throw_on_missing, + ) except Exception as e: format_and_raise(node=cfg, key=key, value=None, cause=e, msg=str(e)) diff --git a/omegaconf/resolvers/oc/dict.py b/omegaconf/resolvers/oc/dict.py index 927082e2f..276e79a30 100644 --- a/omegaconf/resolvers/oc/dict.py +++ b/omegaconf/resolvers/oc/dict.py @@ -53,7 +53,7 @@ def _get_and_validate_dict_input( parent: BaseContainer, resolver_name: str, ) -> DictConfig: - from omegaconf import OmegaConf + from omegaconf._impl import select_value if not isinstance(key, str): raise TypeError( @@ -61,7 +61,7 @@ def _get_and_validate_dict_input( f"of type: {type(key).__name__}" ) - in_dict = OmegaConf.select( + in_dict = select_value( parent, key, throw_on_missing=True, diff --git a/tests/test_select.py b/tests/test_select.py index 51658cdc1..d4caffb12 100644 --- a/tests/test_select.py +++ b/tests/test_select.py @@ -5,6 +5,7 @@ from pytest import mark, param, raises from omegaconf import MissingMandatoryValue, OmegaConf +from omegaconf._impl import select_value from omegaconf._utils import _ensure_container from omegaconf.errors import ConfigKeyError, InterpolationKeyError @@ -265,8 +266,8 @@ def test_select_from_nested_node_with_a_relative_key( ) -> None: cfg = OmegaConf.create(inp) # select returns the same result when a key is relative independent of absolute_key flag. - assert OmegaConf.select(cfg.a, key, absolute_key=False) == expected - assert OmegaConf.select(cfg.a, key, absolute_key=True) == expected + assert select_value(cfg.a, key, absolute_key=False) == expected + assert select_value(cfg.a, key, absolute_key=True) == expected @mark.parametrize( ("key", "expected"), @@ -282,7 +283,7 @@ def test_select_from_nested_node_relative_key_interpretation( self, key: str, expected: Any ) -> None: cfg = OmegaConf.create(inp) - assert OmegaConf.select(cfg.a, key, absolute_key=False) == expected + assert select_value(cfg.a, key, absolute_key=False) == expected @mark.parametrize( ("key", "expected"), @@ -300,4 +301,4 @@ def test_select_from_nested_node_absolute_key_interpretation( self, key: str, expected: Any ) -> None: cfg = OmegaConf.create(inp) - assert OmegaConf.select(cfg.a, key, absolute_key=True) == expected + assert select_value(cfg.a, key, absolute_key=True) == expected From c31e7b859780813e9cc6b78aad2d944f771df2ca Mon Sep 17 00:00:00 2001 From: Omry Yadan Date: Sun, 11 Apr 2021 17:18:08 -0700 Subject: [PATCH 3/3] oc.deprecated support --- docs/source/custom_resolvers.rst | 58 +++++++++- docs/source/usage.rst | 8 +- news/681.feature | 1 + omegaconf/_impl.py | 4 +- omegaconf/base.py | 12 +- omegaconf/omegaconf.py | 17 ++- omegaconf/resolvers/oc/__init__.py | 41 ++++++- .../built_in_resolvers/test_oc_deprecated.py | 107 ++++++++++++++++++ 8 files changed, 232 insertions(+), 16 deletions(-) create mode 100644 news/681.feature create mode 100644 tests/interpolation/built_in_resolvers/test_oc_deprecated.py diff --git a/docs/source/custom_resolvers.rst b/docs/source/custom_resolvers.rst index b1ad14c12..3e8badf3e 100644 --- a/docs/source/custom_resolvers.rst +++ b/docs/source/custom_resolvers.rst @@ -2,6 +2,9 @@ from omegaconf import OmegaConf, DictConfig import os + import pytest + os.environ['USER'] = 'omry' + def show(x): print(f"type: {type(x).__name__}, value: {repr(x)}") @@ -131,14 +134,14 @@ the inputs themselves: Custom interpolations can also receive the following special parameters: -- ``_parent_`` : the parent node of an interpolation. +- ``_parent_``: the parent node of an interpolation. - ``_root_``: The config root. This can be achieved by adding the special parameters to the resolver signature. -Note that special parameters must be defined as named keywords (after the `*`): +Note that special parameters must be defined as named keywords (after the `*`). -In this example, we use ``_parent_`` to implement a sum function that defaults to 0 if the node does not exist. -(In contrast to the sum we defined earlier where accessing an invalid key, e.g. ``"a_plus_z": ${sum:${a}, ${z}}`` will result in an error). +In the example below, we use ``_parent_`` to implement a sum function that defaults to 0 if the node does not exist. +(In contrast to the sum we defined earlier where accessing an invalid key, e.g. ``"a_plus_z": ${sum:${a}, ${z}}`` would result in an error). .. doctest:: @@ -189,6 +192,53 @@ In such a case, the default value is converted to a string using ``str(default)` The following example falls back to default passwords when ``DB_PASSWORD`` is not defined: +.. doctest:: + + >>> cfg = OmegaConf.create( + ... { + ... "database": { + ... "password1": "${oc.env:DB_PASSWORD,password}", + ... "password2": "${oc.env:DB_PASSWORD,12345}", + ... "password3": "${oc.env:DB_PASSWORD,null}", + ... }, + ... } + ... ) + >>> # default is already a string + >>> show(cfg.database.password1) + type: str, value: 'password' + >>> # default is converted to a string automatically + >>> show(cfg.database.password2) + type: str, value: '12345' + >>> # unless it's None + >>> show(cfg.database.password3) + type: NoneType, value: None + +.. _oc.deprecated: + +oc.deprecated +^^^^^^^^^^^^^ +``oc.deprecated`` enables you to deprecate a config node. +It takes two parameters: + +- ``key``: An interpolation key representing the new key you are migrating to. This parameter is required. +- ``message``: A message to use as the warning when the config node is being accessed. The default message is + ``'$OLD_KEY' is deprecated. Change your code and config to use '$NEW_KEY'``. + +.. doctest:: + + >>> conf = OmegaConf.create({ + ... "rusty_key": "${oc.deprecated:shiny_key}", + ... "custom_msg": "${oc.deprecated:shiny_key, 'Use $NEW_KEY'}", + ... "shiny_key": 10 + ... }) + >>> # Accessing rusty_key will issue a deprecation warning + >>> # and return the new value automatically + >>> warning = "'rusty_key' is deprecated. Change your" \ + ... " code and config to use 'shiny_key'" + >>> with pytest.warns(UserWarning, match=warning): + ... assert conf.rusty_key == 10 + >>> with pytest.warns(UserWarning, match="Use shiny_key"): + ... assert conf.custom_msg == 10 .. _oc.decode: diff --git a/docs/source/usage.rst b/docs/source/usage.rst index a02c481fe..054bbe14d 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -5,7 +5,6 @@ import sys import tempfile import pickle - os.environ['USER'] = 'omry' # ensures that DB_TIMEOUT is not set in the doc. os.environ.pop('DB_TIMEOUT', None) @@ -385,9 +384,9 @@ Interpolated nodes can be any node in the config, not just leaf nodes: Resolvers ^^^^^^^^^ -You can add additional interpolation types by registering resolvers using ``OmegaConf.register_new_resolver()``. +Add new interpolation types by registering resolvers using ``OmegaConf.register_new_resolver()``. Such resolvers are called when the config node is accessed. -See :doc:`custom_resolvers` for more details, or keep reading for a minimal example. +The minimal example below shows its most basic usage, see :doc:`custom_resolvers` for more details. .. doctest:: @@ -403,7 +402,8 @@ Built-in resolvers OmegaConf comes with a set of built-in custom resolvers: * :ref:`oc.decode`: Parsing an input string using interpolation grammar -* :ref:`oc.env`: Accessing environment variables. +* :ref:`oc.deprecated`: Deprecate a key in your config +* :ref:`oc.env`: Accessing environment variables * :ref:`oc.dict.{keys,values}`: Viewing the keys or the values of a dictionary as a list diff --git a/news/681.feature b/news/681.feature new file mode 100644 index 000000000..e66dcac5f --- /dev/null +++ b/news/681.feature @@ -0,0 +1 @@ +Introduce oc.deprecated resolver, that enables deprecating config nodes diff --git a/omegaconf/_impl.py b/omegaconf/_impl.py index 3f9deba49..8ae40861c 100644 --- a/omegaconf/_impl.py +++ b/omegaconf/_impl.py @@ -63,7 +63,9 @@ def select_value( throw_on_missing=throw_on_missing, absolute_key=absolute_key, ) + if isinstance(ret, Node) and ret._is_missing(): + assert not throw_on_missing # would have raised an exception in select_node return None return _get_value(ret) @@ -106,6 +108,4 @@ def select_node( ): return default - if value is not None and value._is_missing(): - return None return value diff --git a/omegaconf/base.py b/omegaconf/base.py index 346952e4e..2314681be 100644 --- a/omegaconf/base.py +++ b/omegaconf/base.py @@ -633,7 +633,17 @@ def _evaluate_custom_resolver( resolver = OmegaConf._get_resolver(inter_type) if resolver is not None: root_node = self._get_root() - return resolver(root_node, self, inter_args, inter_args_str) + node = None + if key is not None: + node = self._get_node(key, validate_access=True) + assert node is None or isinstance(node, Node) + return resolver( + root_node, + self, + node, + inter_args, + inter_args_str, + ) else: raise UnsupportedInterpolationType( f"Unsupported interpolation type {inter_type}" diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index b08ecb9fd..825535b2d 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -95,6 +95,7 @@ def register_default_resolvers() -> None: from omegaconf.resolvers import env, oc OmegaConf.register_new_resolver("oc.decode", oc.decode) + OmegaConf.register_new_resolver("oc.deprecated", oc.deprecated) OmegaConf.register_new_resolver("oc.env", oc.env) OmegaConf.register_new_resolver("oc.dict.keys", oc.dict.keys) OmegaConf.register_new_resolver("oc.dict.values", oc.dict.values) @@ -324,7 +325,8 @@ def legacy_register_resolver(name: str, resolver: Resolver) -> None: def resolver_wrapper( config: BaseContainer, - node: BaseContainer, + parent: BaseContainer, + node: Node, args: Tuple[Any, ...], args_str: Tuple[str, ...], ) -> Any: @@ -402,11 +404,13 @@ def _should_pass(special: str) -> bool: return ret pass_parent = _should_pass("_parent_") + pass_node = _should_pass("_node_") pass_root = _should_pass("_root_") def resolver_wrapper( config: BaseContainer, parent: Container, + node: Node, args: Tuple[Any, ...], args_str: Tuple[str, ...], ) -> Any: @@ -418,9 +422,11 @@ def resolver_wrapper( pass # Call resolver. - kwargs = {} + kwargs: Dict[str, Node] = {} if pass_parent: kwargs["_parent_"] = parent + if pass_node: + kwargs["_node_"] = node if pass_root: kwargs["_root_"] = config @@ -443,7 +449,7 @@ def get_resolver( cls, name: str, ) -> Optional[ - Callable[[Container, Container, Tuple[Any, ...], Tuple[str, ...]], Any] + Callable[[Container, Container, Node, Tuple[Any, ...], Tuple[str, ...]], Any] ]: warnings.warn( "`OmegaConf.get_resolver()` is deprecated (see https://github.com/omry/omegaconf/issues/608)", @@ -878,7 +884,10 @@ def _get_obj_type(c: Any) -> Optional[Type[Any]]: def _get_resolver( name: str, ) -> Optional[ - Callable[[Container, Container, Tuple[Any, ...], Tuple[str, ...]], Any] + Callable[ + [Container, Container, Optional[Node], Tuple[Any, ...], Tuple[str, ...]], + Any, + ] ]: # noinspection PyProtectedMember return ( diff --git a/omegaconf/resolvers/oc/__init__.py b/omegaconf/resolvers/oc/__init__.py index 72b2be2b9..0d4baae4a 100644 --- a/omegaconf/resolvers/oc/__init__.py +++ b/omegaconf/resolvers/oc/__init__.py @@ -1,8 +1,11 @@ import os +import string +import warnings from typing import Any, Optional -from omegaconf import Container +from omegaconf import Container, Node from omegaconf._utils import _DEFAULT_MARKER_, _get_value +from omegaconf.errors import ConfigKeyError from omegaconf.grammar_parser import parse from omegaconf.resolvers.oc import dict @@ -46,8 +49,44 @@ def decode(expr: Optional[str], _parent_: Container) -> Any: return _get_value(val) +def deprecated( + key: str, + message: str = "'$OLD_KEY' is deprecated. Change your code and config to use '$NEW_KEY'", + *, + _parent_: Container, + _node_: Optional[Node], +) -> Any: + from omegaconf._impl import select_node + + if not isinstance(key, str): + raise TypeError( + f"oc.deprecated: interpolation key type is not a string ({type(key).__name__})" + ) + + if not isinstance(message, str): + raise TypeError( + f"oc.deprecated: interpolation message type is not a string ({type(message).__name__})" + ) + + assert _node_ is not None + full_key = _node_._get_full_key(key=None) + target_node = select_node(_parent_, key, absolute_key=True) + if target_node is None: + raise ConfigKeyError( + f"In oc.deprecated resolver at '{full_key}': Key not found: '{key}'" + ) + new_key = target_node._get_full_key(key=None) + msg = string.Template(message).safe_substitute( + OLD_KEY=full_key, + NEW_KEY=new_key, + ) + warnings.warn(category=UserWarning, message=msg) + return target_node + + __all__ = [ "decode", + "deprecated", "dict", "env", ] diff --git a/tests/interpolation/built_in_resolvers/test_oc_deprecated.py b/tests/interpolation/built_in_resolvers/test_oc_deprecated.py new file mode 100644 index 000000000..7f03ab34a --- /dev/null +++ b/tests/interpolation/built_in_resolvers/test_oc_deprecated.py @@ -0,0 +1,107 @@ +import re +from typing import Any + +from pytest import mark, param, raises, warns + +from omegaconf import OmegaConf +from omegaconf.errors import InterpolationResolutionError + + +@mark.parametrize( + ("cfg", "key", "expected_value", "expected_warning"), + [ + param( + {"a": 10, "b": "${oc.deprecated: a}"}, + "b", + 10, + "'b' is deprecated. Change your code and config to use 'a'", + id="value", + ), + param( + {"a": 10, "b": "${oc.deprecated: a, '$OLD_KEY is deprecated'}"}, + "b", + 10, + "b is deprecated", + id="value-custom-message", + ), + param( + { + "a": 10, + "b": "${oc.deprecated: a, ${warning}}", + "warning": "$OLD_KEY is bad, $NEW_KEY is good", + }, + "b", + 10, + "b is bad, a is good", + id="value-custom-message-config-variable", + ), + param( + {"a": {"b": 10}, "b": "${oc.deprecated: a}"}, + "b", + OmegaConf.create({"b": 10}), + "'b' is deprecated. Change your code and config to use 'a'", + id="dict", + ), + param( + {"a": {"b": 10}, "b": "${oc.deprecated: a}"}, + "b.b", + 10, + "'b' is deprecated. Change your code and config to use 'a'", + id="dict_value", + ), + param( + {"a": [0, 1], "b": "${oc.deprecated: a}"}, + "b", + OmegaConf.create([0, 1]), + "'b' is deprecated. Change your code and config to use 'a'", + id="list", + ), + param( + {"a": [0, 1], "b": "${oc.deprecated: a}"}, + "b[1]", + 1, + "'b' is deprecated. Change your code and config to use 'a'", + id="list_value", + ), + ], +) +def test_deprecated( + cfg: Any, key: str, expected_value: Any, expected_warning: str +) -> None: + cfg = OmegaConf.create(cfg) + with warns(UserWarning, match=re.escape(expected_warning)): + value = OmegaConf.select(cfg, key) + assert value == expected_value + assert type(value) == type(expected_value) + + +@mark.parametrize( + ("cfg", "error"), + [ + param( + {"a": "${oc.deprecated: z}"}, + "ConfigKeyError raised while resolving interpolation:" + " In oc.deprecated resolver at 'a': Key not found: 'z'", + id="target_not_found", + ), + param( + {"a": "${oc.deprecated: 111111}"}, + "TypeError raised while resolving interpolation: oc.deprecated:" + " interpolation key type is not a string (int)", + id="invalid_key_type", + ), + param( + {"a": "${oc.deprecated: b, 1000}", "b": 10}, + "TypeError raised while resolving interpolation: oc.deprecated:" + " interpolation message type is not a string (int)", + id="invalid_message_type", + ), + ], +) +def test_deprecated_target_not_found(cfg: Any, error: str) -> None: + cfg = OmegaConf.create(cfg) + with raises( + InterpolationResolutionError, + match=re.escape(error), + ): + cfg.a