Skip to content

Commit

Permalink
Revert conversion of dicts/lists output by interpolations
Browse files Browse the repository at this point in the history
They were being converted into DictConfig/ListConfig within
`_node_wrap()`. Now we only call `_node_wrap()` on primitive types,
while other types are stored within an `AnyNode` with the
`allow_objects` flag set to True.
  • Loading branch information
odelalleau committed Apr 16, 2021
1 parent dad564e commit abc60ae
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 161 deletions.
17 changes: 11 additions & 6 deletions docs/notebook/Tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -821,14 +821,19 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Strings may be converted using ``oc.decode``:\n",
"With ``oc.decode``, strings can be converted into their corresponding data types using the OmegaConf grammar.\n",
"This grammar recognizes typical data types like ``bool``, ``int``, ``float``, ``dict`` and ``list``,\n",
"e.g. ``\"true\"``, ``\"1\"``, ``\"1e-3\"``, ``\"{a: b}\"``, ``\"[a, b, c]\"``.\n",
"It will also resolve interpolations like ``\"${foo}\"``, returning the corresponding value of the node.\n",
"\n",
"- Primitive values (e.g., ``\"true\"``, ``\"1\"``, ``\"1e-3\"``) are automatically converted to their corresponding type (bool, int, float)\n",
"- Dictionaries and lists (e.g., ``\"{a: b}\"``, ``\"[a, b, c]\"``) are returned as transient config nodes (DictConfig and ListConfig)\n",
"- Interpolations (e.g., ``\"${foo}\"``) are automatically resolved\n",
"- ``None`` is the only valid non-string input to ``oc.decode`` (returning ``None`` in that case)\n",
"Note that:\n",
"\n",
"This can be useful for instance to parse environment variables:"
"- When providing as input to ``oc.decode`` a string that is meant to be decoded into another string, in general\n",
" the input string should be quoted (since only a subset of characters are allowed by the grammar in unquoted\n",
" strings). For instance, a proper string interpolation could be: ``\"'Hi! My name is: ${name}'\"`` (with extra quotes).\n",
"- ``None`` (written as ``null`` in the grammar) is the only valid non-string input to ``oc.decode`` (returning ``None`` in that case)\n",
"\n",
"This resolver can be useful for instance to parse environment variables:"
]
},
{
Expand Down
21 changes: 13 additions & 8 deletions docs/source/custom_resolvers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -245,14 +245,19 @@ It takes two parameters:
oc.decode
^^^^^^^^^

Strings may be converted using ``oc.decode``:
With ``oc.decode``, strings can be converted into their corresponding data types using the OmegaConf grammar.
This grammar recognizes typical data types like ``bool``, ``int``, ``float``, ``dict`` and ``list``,
e.g. ``"true"``, ``"1"``, ``"1e-3"``, ``"{a: b}"``, ``"[a, b, c]"``.
It will also resolve interpolations like ``"${foo}"``, returning the corresponding value of the node.

- Primitive values (e.g., ``"true"``, ``"1"``, ``"1e-3"``) are automatically converted to their corresponding type (bool, int, float)
- Dictionaries and lists (e.g., ``"{a: b}"``, ``"[a, b, c]"``) are returned as transient config nodes (DictConfig and ListConfig)
- Interpolations (e.g., ``"${foo}"``) are automatically resolved
- ``None`` is the only valid non-string input to ``oc.decode`` (returning ``None`` in that case)
Note that:

This can be useful for instance to parse environment variables:
- When providing as input to ``oc.decode`` a string that is meant to be decoded into another string, in general
the input string should be quoted (since only a subset of characters are allowed by the grammar in unquoted
strings). For instance, a proper string interpolation could be: ``"'Hi! My name is: ${name}'"`` (with extra quotes).
- ``None`` (written as ``null`` in the grammar) is the only valid non-string input to ``oc.decode`` (returning ``None`` in that case)

This resolver can be useful for instance to parse environment variables:

.. doctest::

Expand All @@ -269,8 +274,8 @@ This can be useful for instance to parse environment variables:
>>> show(cfg.database.port) # converted to int
type: int, value: 3308
>>> os.environ["DB_NODES"] = "[host1, host2, host3]"
>>> show(cfg.database.nodes) # converted to a ListConfig
type: ListConfig, value: ['host1', 'host2', 'host3']
>>> show(cfg.database.nodes) # converted to a Python list
type: list, value: ['host1', 'host2', 'host3']
>>> show(cfg.database.timeout) # keeping `None` as is
type: NoneType, value: None
>>> os.environ["DB_TIMEOUT"] = "${.port}"
Expand Down
2 changes: 1 addition & 1 deletion news/488.api_change
Original file line number Diff line number Diff line change
@@ -1 +1 @@
When resolving an interpolation of a typed config value, the interpolated value is validated and possibly converted based on the node's type.
When resolving an interpolation of a config value with a primitive type, the interpolated value is validated and possibly converted based on the node's type.
49 changes: 27 additions & 22 deletions omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
_is_missing_value,
format_and_raise,
get_value_kind,
is_primitive_type,
split_key,
)
from .errors import (
Expand All @@ -23,7 +24,6 @@
InterpolationResolutionError,
InterpolationToMissingValueError,
InterpolationValidationError,
KeyValidationError,
MissingMandatoryValue,
UnsupportedInterpolationType,
ValidationError,
Expand Down Expand Up @@ -552,31 +552,36 @@ def _wrap_interpolation_result(
throw_on_resolution_failure: bool,
) -> Optional["Node"]:
from .basecontainer import BaseContainer
from .omegaconf import _node_wrap
from .nodes import AnyNode
from .omegaconf import _node_wrap, flag_override

assert parent is None or isinstance(parent, BaseContainer)
try:
wrapped = _node_wrap(
type_=value._metadata.ref_type,
parent=parent,
is_optional=value._metadata.optional,
value=resolved,
key=key,
ref_type=value._metadata.ref_type,
)
except (KeyValidationError, ValidationError) as e:
if throw_on_resolution_failure:
self._format_and_raise(
key=key,

if is_primitive_type(type(resolved)):
# Primitive types get wrapped using the same logic as when setting the
# value of a node (i.e., through `_node_wrap()`).
try:
wrapped = _node_wrap(
type_=value._metadata.ref_type,
parent=parent,
is_optional=value._metadata.optional,
value=resolved,
cause=e,
type_override=InterpolationValidationError,
key=key,
ref_type=value._metadata.ref_type,
)
return None
# Since we created a new node on the fly, future changes to this node are
# likely to be lost. We thus set the "readonly" flag to `True` to reduce
# the risk of accidental modifications.
wrapped._set_flag("readonly", True)
except ValidationError: # pragma: no cover
# This is not supposed to happen because primitive types that must
# be wrapped should have already been validated inside
# `_validate_and_convert_interpolation_result()`.
assert False
else:
# Other objects get wrapped into an `AnyNode` with `allow_objects` set
# to True.
wrapped = AnyNode(value=None, key=key, parent=parent)
wrapped._set_flag("allow_objects", True)
with flag_override(wrapped, "readonly", False):
wrapped._set_value(resolved)

return wrapped

def _validate_not_dereferencing_to_parent(self, node: Node, target: Node) -> None:
Expand Down
35 changes: 10 additions & 25 deletions tests/interpolation/built_in_resolvers/test_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,33 +234,18 @@ def test_dict_values_dictconfig_resolver_output(
assert cfg.foo[key] == expected


@mark.parametrize(
("make_resolver", "expected_value", "expected_content"),
[
param(
lambda _parent_: OmegaConf.create({"a": 0, "b": 1}, parent=_parent_),
[0, 1],
["${y.a}", "${y.b}"],
id="dictconfig_with_parent",
),
param(
lambda: {"a": 0, "b": 1},
[0, 1],
["${y.a}", "${y.b}"],
id="plain_dict",
),
],
)
def test_dict_values_transient_interpolation(
@mark.parametrize("dict_func", ["oc.dict.values", "oc.dict.keys"])
def test_extract_from_dict_resolver_output(
restore_resolvers: Any,
make_resolver: Any,
expected_value: Any,
expected_content: Any,
dict_func: str,
) -> None:
OmegaConf.register_new_resolver("make", make_resolver)
cfg = OmegaConf.create({"x": "${oc.dict.values:y}", "y": "${make:}"})
assert cfg.x == expected_value
assert cfg.x._content == expected_content
OmegaConf.register_new_resolver("make_dict", lambda: {"a": 0, "b": 1})
cfg = OmegaConf.create({"x": f"${{{dict_func}:y}}", "y": "${make_dict:}"})
with raises(
InterpolationResolutionError,
match="TypeError raised while resolving interpolation",
):
cfg.x


def test_dict_values_are_typed() -> None:
Expand Down
57 changes: 21 additions & 36 deletions tests/interpolation/test_custom_resolvers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import random
import re
from typing import Any, Dict, List
from typing import Any

from pytest import mark, param, raises, warns

from omegaconf import DictConfig, ListConfig, OmegaConf, Resolver
from omegaconf import OmegaConf, Resolver
from omegaconf.nodes import AnyNode
from tests.interpolation import dereference_node


Expand Down Expand Up @@ -345,44 +346,28 @@ def test_clear_cache(restore_resolvers: Any) -> None:
assert old != c.k


@mark.parametrize(
("cfg", "expected"),
[
({"a": 0, "b": 1}, {"a": 0, "b": 1}),
({"a": "${y}"}, {"a": -1}),
({"a": 0, "b": "${x.a}"}, {"a": 0, "b": 0}),
({"a": 0, "b": "${.a}"}, {"a": 0, "b": 0}),
({"a": "${..y}"}, {"a": -1}),
],
)
def test_resolver_output_dict_to_dictconfig(
restore_resolvers: Any, cfg: Dict[str, Any], expected: Dict[str, Any]
) -> None:
OmegaConf.register_new_resolver("dict", lambda: cfg)
@mark.parametrize("readonly", [True, False, None])
def test_resolver_output_dict(restore_resolvers: Any, readonly: bool) -> None:
some_dict = {"a": 0, "b": "${y}"}
OmegaConf.register_new_resolver("dict", lambda: some_dict)
c = OmegaConf.create({"x": "${dict:}", "y": -1})
assert isinstance(c.x, DictConfig)
assert c.x == expected
assert dereference_node(c, "x")._get_flag("readonly")
c._set_flag("readonly", readonly)
assert c.x == some_dict
x_node = dereference_node(c, "x")
assert isinstance(x_node, AnyNode)
assert x_node._get_flag("allow_objects")


@mark.parametrize(
("cfg", "expected"),
[
([0, 1], [0, 1]),
(["${y}"], [-1]),
([0, "${x.0}"], [0, 0]),
([0, "${.0}"], [0, 0]),
(["${..y}"], [-1]),
],
)
def test_resolver_output_list_to_listconfig(
restore_resolvers: Any, cfg: List[Any], expected: List[Any]
) -> None:
OmegaConf.register_new_resolver("list", lambda: cfg)
@mark.parametrize("readonly", [True, False, None])
def test_resolver_output_list(restore_resolvers: Any, readonly: bool) -> None:
some_list = ["a", 0, "${y}"]
OmegaConf.register_new_resolver("list", lambda: some_list)
c = OmegaConf.create({"x": "${list:}", "y": -1})
assert isinstance(c.x, ListConfig)
assert c.x == expected
assert dereference_node(c, "x")._get_flag("readonly")
c._set_flag("readonly", readonly)
assert c.x == some_list
x_node = dereference_node(c, "x")
assert isinstance(x_node, AnyNode)
assert x_node._get_flag("allow_objects")


def test_register_cached_resolver_with_keyword_unsupported() -> None:
Expand Down
Loading

0 comments on commit abc60ae

Please sign in to comment.