diff --git a/docs/notebook/Tutorial.ipynb b/docs/notebook/Tutorial.ipynb index 7088d41d8..772ae00e6 100644 --- a/docs/notebook/Tutorial.ipynb +++ b/docs/notebook/Tutorial.ipynb @@ -821,14 +821,19 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Strings may be converted using ``oc.decode``:\n", + "With ``oc.decode``, strings can be converted into their corresponding data types using the OmegaConf grammar.\n", + "This grammar recognizes typical data types like ``bool``, ``int``, ``float``, ``dict`` and ``list``,\n", + "e.g. ``\"true\"``, ``\"1\"``, ``\"1e-3\"``, ``\"{a: b}\"``, ``\"[a, b, c]\"``.\n", + "It will also resolve interpolations like ``\"${foo}\"``, returning the corresponding value of the node.\n", "\n", - "- Primitive values (e.g., ``\"true\"``, ``\"1\"``, ``\"1e-3\"``) are automatically converted to their corresponding type (bool, int, float)\n", - "- Dictionaries and lists (e.g., ``\"{a: b}\"``, ``\"[a, b, c]\"``) are returned as transient config nodes (DictConfig and ListConfig)\n", - "- Interpolations (e.g., ``\"${foo}\"``) are automatically resolved\n", - "- ``None`` is the only valid non-string input to ``oc.decode`` (returning ``None`` in that case)\n", + "Note that:\n", "\n", - "This can be useful for instance to parse environment variables:" + "- When providing as input to ``oc.decode`` a string that is meant to be decoded into another string, in general\n", + " the input string should be quoted (since only a subset of characters are allowed by the grammar in unquoted\n", + " strings). For instance, a proper string interpolation could be: ``\"'Hi! My name is: ${name}'\"`` (with extra quotes).\n", + "- ``None`` (written as ``null`` in the grammar) is the only valid non-string input to ``oc.decode`` (returning ``None`` in that case)\n", + "\n", + "This resolver can be useful for instance to parse environment variables:" ] }, { diff --git a/docs/source/custom_resolvers.rst b/docs/source/custom_resolvers.rst index 3e8badf3e..fe38f6998 100644 --- a/docs/source/custom_resolvers.rst +++ b/docs/source/custom_resolvers.rst @@ -60,19 +60,6 @@ simply use quotes to bypass character limitations in strings. 'Hello, World' -Custom resolvers can return lists or dictionaries, that are automatically converted into DictConfig and ListConfig: - -.. doctest:: - - >>> OmegaConf.register_new_resolver( - ... "min_max", lambda *a: {"min": min(a), "max": max(a)} - ... ) - >>> c = OmegaConf.create({'stats': '${min_max: -1, 3, 2, 5, -10}'}) - >>> assert isinstance(c.stats, DictConfig) - >>> c.stats.min, c.stats.max - (-10, 5) - - You can take advantage of nested interpolations to perform custom operations over variables: .. doctest:: @@ -213,6 +200,37 @@ The following example falls back to default passwords when ``DB_PASSWORD`` is no >>> show(cfg.database.password3) type: NoneType, value: None + +.. _oc.create: + +oc.create +^^^^^^^^^ + +``oc.create`` may be used for dynamic generation of config nodes +(typically from Python ``dict`` / ``list`` objects or YAML strings, similar to :ref:`OmegaConf.create`). + +.. doctest:: + + + >>> OmegaConf.register_new_resolver("make_dict", lambda: {"a": 10}) + >>> cfg = OmegaConf.create( + ... { + ... "plain_dict": "${make_dict:}", + ... "dict_config": "${oc.create:${make_dict:}}", + ... "dict_config_env": "${oc.create:${oc.env:YAML_ENV}}", + ... } + ... ) + >>> os.environ["YAML_ENV"] = "A: 10\nb: 20\nC: ${.A}" + >>> show(cfg.plain_dict) # `make_dict` returns a Python dict + type: dict, value: {'a': 10} + >>> show(cfg.dict_config) # `oc.create` converts it to DictConfig + type: DictConfig, value: {'a': 10} + >>> show(cfg.dict_config_env) # YAML string to DictConfig + type: DictConfig, value: {'A': 10, 'b': 20, 'C': '${.A}'} + >>> cfg.dict_config_env.C # interpolations work in a DictConfig + 10 + + .. _oc.deprecated: oc.deprecated @@ -245,14 +263,16 @@ It takes two parameters: oc.decode ^^^^^^^^^ -Strings may be converted using ``oc.decode``: +With ``oc.decode``, strings can be converted into their corresponding data types using the OmegaConf grammar. +This grammar recognizes typical data types like ``bool``, ``int``, ``float``, ``dict`` and ``list``, +e.g. ``"true"``, ``"1"``, ``"1e-3"``, ``"{a: b}"``, ``"[a, b, c]"``. + +Note that: -- Primitive values (e.g., ``"true"``, ``"1"``, ``"1e-3"``) are automatically converted to their corresponding type (bool, int, float) -- Dictionaries and lists (e.g., ``"{a: b}"``, ``"[a, b, c]"``) are returned as transient config nodes (DictConfig and ListConfig) -- Interpolations (e.g., ``"${foo}"``) are automatically resolved -- ``None`` is the only valid non-string input to ``oc.decode`` (returning ``None`` in that case) +- In general input strings provided to ``oc.decode`` should be quoted, since only a subset of the characters is allowed in unquoted strings +- ``None`` (written as ``null`` in the grammar) is the only valid non-string input to ``oc.decode`` (returning ``None`` in that case) -This can be useful for instance to parse environment variables: +This resolver can be useful for instance to parse environment variables: .. doctest:: @@ -269,8 +289,8 @@ This can be useful for instance to parse environment variables: >>> show(cfg.database.port) # converted to int type: int, value: 3308 >>> os.environ["DB_NODES"] = "[host1, host2, host3]" - >>> show(cfg.database.nodes) # converted to a ListConfig - type: ListConfig, value: ['host1', 'host2', 'host3'] + >>> show(cfg.database.nodes) # converted to a Python list + type: list, value: ['host1', 'host2', 'host3'] >>> show(cfg.database.timeout) # keeping `None` as is type: NoneType, value: None >>> os.environ["DB_TIMEOUT"] = "${.port}" diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 054bbe14d..318b23114 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -22,6 +22,8 @@ Just pip install:: OmegaConf requires Python 3.6 and newer. +.. _creating: + Creating -------- You can create OmegaConf objects from multiple sources. @@ -401,6 +403,7 @@ Built-in resolvers ^^^^^^^^^^^^^^^^^^ OmegaConf comes with a set of built-in custom resolvers: +* :ref:`oc.create`: Dynamically generating config nodes * :ref:`oc.decode`: Parsing an input string using interpolation grammar * :ref:`oc.deprecated`: Deprecate a key in your config * :ref:`oc.env`: Accessing environment variables diff --git a/news/488.api_change b/news/488.api_change index a7b269636..3799c2e67 100644 --- a/news/488.api_change +++ b/news/488.api_change @@ -1 +1 @@ -When resolving an interpolation of a typed config value, the interpolated value is validated and possibly converted based on the node's type. +When resolving an interpolation of a config value with a primitive type, the interpolated value is validated and possibly converted based on the node's type. diff --git a/news/645.feature b/news/645.feature new file mode 100644 index 000000000..d03e889c9 --- /dev/null +++ b/news/645.feature @@ -0,0 +1 @@ +The new built-in resolver `oc.create` can be used to dynamically generate config nodes diff --git a/omegaconf/base.py b/omegaconf/base.py index 2314681be..4f4d14296 100644 --- a/omegaconf/base.py +++ b/omegaconf/base.py @@ -15,6 +15,7 @@ _is_missing_value, format_and_raise, get_value_kind, + is_primitive_type, split_key, ) from .errors import ( @@ -23,7 +24,6 @@ InterpolationResolutionError, InterpolationToMissingValueError, InterpolationValidationError, - KeyValidationError, MissingMandatoryValue, UnsupportedInterpolationType, ValidationError, @@ -552,10 +552,14 @@ def _wrap_interpolation_result( throw_on_resolution_failure: bool, ) -> Optional["Node"]: from .basecontainer import BaseContainer + from .nodes import AnyNode from .omegaconf import _node_wrap assert parent is None or isinstance(parent, BaseContainer) - try: + + if is_primitive_type(type(resolved)): + # Primitive types get wrapped using `_node_wrap()`, ensuring value is + # validated and potentially converted. wrapped = _node_wrap( type_=value._metadata.ref_type, parent=parent, @@ -564,19 +568,14 @@ def _wrap_interpolation_result( key=key, ref_type=value._metadata.ref_type, ) - except (KeyValidationError, ValidationError) as e: - if throw_on_resolution_failure: - self._format_and_raise( - key=key, - value=resolved, - cause=e, - type_override=InterpolationValidationError, - ) - return None - # Since we created a new node on the fly, future changes to this node are - # likely to be lost. We thus set the "readonly" flag to `True` to reduce - # the risk of accidental modifications. - wrapped._set_flag("readonly", True) + else: + # Other objects get wrapped into an `AnyNode` with `allow_objects` set + # to True. + wrapped = AnyNode( + value=resolved, key=key, parent=None, flags={"allow_objects": True} + ) + wrapped._set_parent(parent) + return wrapped def _validate_not_dereferencing_to_parent(self, node: Node, target: Node) -> None: diff --git a/omegaconf/nodes.py b/omegaconf/nodes.py index 0f71b029b..400f531f9 100644 --- a/omegaconf/nodes.py +++ b/omegaconf/nodes.py @@ -111,11 +111,14 @@ def __init__( value: Any = None, key: Any = None, parent: Optional[Container] = None, + flags: Optional[Dict[str, bool]] = None, ): super().__init__( parent=parent, value=value, - metadata=Metadata(ref_type=Any, object_type=None, key=key, optional=True), + metadata=Metadata( + ref_type=Any, object_type=None, key=key, optional=True, flags=flags + ), ) def _validate_and_convert_impl(self, value: Any) -> Any: @@ -145,12 +148,17 @@ def __init__( key: Any = None, parent: Optional[Container] = None, is_optional: bool = True, + flags: Optional[Dict[str, bool]] = None, ): super().__init__( parent=parent, value=value, metadata=Metadata( - key=key, optional=is_optional, ref_type=str, object_type=str + key=key, + optional=is_optional, + ref_type=str, + object_type=str, + flags=flags, ), ) @@ -174,12 +182,17 @@ def __init__( key: Any = None, parent: Optional[Container] = None, is_optional: bool = True, + flags: Optional[Dict[str, bool]] = None, ): super().__init__( parent=parent, value=value, metadata=Metadata( - key=key, optional=is_optional, ref_type=int, object_type=int + key=key, + optional=is_optional, + ref_type=int, + object_type=int, + flags=flags, ), ) @@ -206,12 +219,17 @@ def __init__( key: Any = None, parent: Optional[Container] = None, is_optional: bool = True, + flags: Optional[Dict[str, bool]] = None, ): super().__init__( parent=parent, value=value, metadata=Metadata( - key=key, optional=is_optional, ref_type=float, object_type=float + key=key, + optional=is_optional, + ref_type=float, + object_type=float, + flags=flags, ), ) @@ -255,12 +273,17 @@ def __init__( key: Any = None, parent: Optional[Container] = None, is_optional: bool = True, + flags: Optional[Dict[str, bool]] = None, ): super().__init__( parent=parent, value=value, metadata=Metadata( - key=key, optional=is_optional, ref_type=bool, object_type=bool + key=key, + optional=is_optional, + ref_type=bool, + object_type=bool, + flags=flags, ), ) @@ -307,6 +330,7 @@ def __init__( key: Any = None, parent: Optional[Container] = None, is_optional: bool = True, + flags: Optional[Dict[str, bool]] = None, ): if not isinstance(enum_type, type) or not issubclass(enum_type, Enum): raise ValidationError( @@ -320,7 +344,11 @@ def __init__( parent=parent, value=value, metadata=Metadata( - key=key, optional=is_optional, ref_type=enum_type, object_type=enum_type + key=key, + optional=is_optional, + ref_type=enum_type, + object_type=enum_type, + flags=flags, ), ) diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index 825535b2d..be555262b 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -94,6 +94,7 @@ def SI(interpolation: str) -> Any: def register_default_resolvers() -> None: from omegaconf.resolvers import env, oc + OmegaConf.register_new_resolver("oc.create", oc.create) OmegaConf.register_new_resolver("oc.decode", oc.decode) OmegaConf.register_new_resolver("oc.deprecated", oc.deprecated) OmegaConf.register_new_resolver("oc.env", oc.env) diff --git a/omegaconf/resolvers/oc/__init__.py b/omegaconf/resolvers/oc/__init__.py index 0d4baae4a..90a779982 100644 --- a/omegaconf/resolvers/oc/__init__.py +++ b/omegaconf/resolvers/oc/__init__.py @@ -5,11 +5,20 @@ from omegaconf import Container, Node from omegaconf._utils import _DEFAULT_MARKER_, _get_value +from omegaconf.basecontainer import BaseContainer from omegaconf.errors import ConfigKeyError from omegaconf.grammar_parser import parse from omegaconf.resolvers.oc import dict +def create(obj: Any, _parent_: Container) -> Any: + """Create a config object from `obj`, similar to `OmegaConf.create`""" + from omegaconf import OmegaConf + + assert isinstance(_parent_, BaseContainer) + return OmegaConf.create(obj, parent=_parent_) + + def env(key: str, default: Any = _DEFAULT_MARKER_) -> Optional[str]: """ :param key: Environment variable key @@ -85,6 +94,7 @@ def deprecated( __all__ = [ + "create", "decode", "deprecated", "dict", diff --git a/tests/interpolation/built_in_resolvers/test_create_resolver.py b/tests/interpolation/built_in_resolvers/test_create_resolver.py new file mode 100644 index 000000000..c9c239b9f --- /dev/null +++ b/tests/interpolation/built_in_resolvers/test_create_resolver.py @@ -0,0 +1,139 @@ +from typing import Any, Dict, List + +from pytest import mark, param, raises + +from omegaconf import DictConfig, ListConfig, OmegaConf +from omegaconf.errors import InterpolationResolutionError + + +@mark.parametrize( + ("cfg", "key", "expected"), + [ + # Note that since `oc.create` is simply calling `OmegaConf.create`, which is + # already thoroughly tested, we do not do extensive tests here. + param( + {"x": "${oc.create:{a: 0, b: 1}}"}, + "x", + OmegaConf.create({"a": 0, "b": 1}), + id="dict", + ), + param( + {"x": "${oc.create:[0, 1, 2]}"}, + "x", + OmegaConf.create([0, 1, 2]), + id="list", + ), + param( + {"x": "${oc.create:{a: 0, b: ${y}}}", "y": 5}, + "x", + OmegaConf.create({"a": 0, "b": 5}), + id="dict:interpolated_value", + ), + param( + {"x": "${oc.create:[0, 1, ${y}]}", "y": 5}, + "x", + OmegaConf.create([0, 1, 5]), + id="list:interpolated_value", + ), + param( + {"x": "${oc.create:${y}}", "y": {"a": 0}}, + "x", + OmegaConf.create({"a": 0}), + id="dict:interpolated_node", + ), + param( + {"x": "${oc.create:${y}}", "y": [0, 1]}, + "x", + OmegaConf.create([0, 1]), + id="list:interpolated_node", + ), + ], +) +def test_create(cfg: Any, key: str, expected: Any) -> None: + cfg = OmegaConf.create(cfg) + val = cfg[key] + assert val == expected + assert type(val) is type(expected) + assert val._get_flag("readonly") is None + + +def test_create_error() -> None: + cfg = OmegaConf.create({"x": "${oc.create:0}"}) + with raises(InterpolationResolutionError, match="ValidationError"): + cfg.x + + +def test_write_into_output() -> None: + cfg = OmegaConf.create( + { + "x": "${oc.create:${y}}", + "y": { + "a": 0, + "b": {"c": 1}, + }, + } + ) + x = cfg.x + assert x._get_flag("readonly") is None + + # Write into the node generated by `oc.create`. + x.a = 1 + x.b.c = 2 + + # The node that we wrote into should be modified. + assert x.a == 1 + assert x.b.c == 2 + + # The interpolated node should not be modified. + assert cfg.y.a == 0 + assert cfg.y.b.c == 1 + + # Re-accessing the node "forgets" the changes. + assert cfg.x.a == 0 + assert cfg.x.b.c == 1 + + +@mark.parametrize( + ("cfg", "expected"), + [ + ({"a": 0, "b": 1}, {"a": 0, "b": 1}), + ({"a": "${y}"}, {"a": -1}), + ({"a": 0, "b": "${x.a}"}, {"a": 0, "b": 0}), + ({"a": 0, "b": "${.a}"}, {"a": 0, "b": 0}), + ({"a": "${..y}"}, {"a": -1}), + ], +) +def test_resolver_output_dict_to_dictconfig( + restore_resolvers: Any, cfg: Dict[str, Any], expected: Dict[str, Any] +) -> None: + OmegaConf.register_new_resolver("dict", lambda: cfg) + c = OmegaConf.create({"x": "${oc.create:${dict:}}", "y": -1}) + assert isinstance(c.x, DictConfig) + assert c.x == expected + assert c.x._parent is c + + +@mark.parametrize( + ("cfg", "expected"), + [ + ([0, 1], [0, 1]), + (["${y}"], [-1]), + ([0, "${x.0}"], [0, 0]), + ([0, "${.0}"], [0, 0]), + (["${..y}"], [-1]), + ], +) +def test_resolver_output_list_to_listconfig( + restore_resolvers: Any, cfg: List[Any], expected: List[Any] +) -> None: + OmegaConf.register_new_resolver("list", lambda: cfg) + c = OmegaConf.create({"x": "${oc.create:${list:}}", "y": -1}) + assert isinstance(c.x, ListConfig) + assert c.x == expected + assert c.x._parent is c + + +def test_merge_into_created_node() -> None: + cfg: Any = OmegaConf.create({"x": "${oc.create:{y: 0}}"}) + cfg = OmegaConf.merge(cfg, {"x": {"z": 1}}) + assert cfg == {"x": {"y": 0, "z": 1}} diff --git a/tests/interpolation/built_in_resolvers/test_dict.py b/tests/interpolation/built_in_resolvers/test_dict.py index d988ec4ec..8757f4ec3 100644 --- a/tests/interpolation/built_in_resolvers/test_dict.py +++ b/tests/interpolation/built_in_resolvers/test_dict.py @@ -234,35 +234,6 @@ def test_dict_values_dictconfig_resolver_output( assert cfg.foo[key] == expected -@mark.parametrize( - ("make_resolver", "expected_value", "expected_content"), - [ - param( - lambda _parent_: OmegaConf.create({"a": 0, "b": 1}, parent=_parent_), - [0, 1], - ["${y.a}", "${y.b}"], - id="dictconfig_with_parent", - ), - param( - lambda: {"a": 0, "b": 1}, - [0, 1], - ["${y.a}", "${y.b}"], - id="plain_dict", - ), - ], -) -def test_dict_values_transient_interpolation( - restore_resolvers: Any, - make_resolver: Any, - expected_value: Any, - expected_content: Any, -) -> None: - OmegaConf.register_new_resolver("make", make_resolver) - cfg = OmegaConf.create({"x": "${oc.dict.values:y}", "y": "${make:}"}) - assert cfg.x == expected_value - assert cfg.x._content == expected_content - - def test_dict_values_are_typed() -> None: cfg = OmegaConf.create( { @@ -289,7 +260,7 @@ def test_dict_values_are_typed() -> None: ) def test_readonly_parent(cfg: Any, expected: Any) -> None: cfg = OmegaConf.create(cfg) - cfg._set_flag("readonly", True) + OmegaConf.set_readonly(cfg, True) assert cfg.x == expected diff --git a/tests/interpolation/test_custom_resolvers.py b/tests/interpolation/test_custom_resolvers.py index 42c7ce0ec..15292f555 100644 --- a/tests/interpolation/test_custom_resolvers.py +++ b/tests/interpolation/test_custom_resolvers.py @@ -1,10 +1,11 @@ import random import re -from typing import Any, Dict, List +from typing import Any from pytest import mark, param, raises, warns -from omegaconf import DictConfig, ListConfig, OmegaConf, Resolver +from omegaconf import OmegaConf, Resolver +from omegaconf.nodes import AnyNode from tests.interpolation import dereference_node @@ -345,44 +346,40 @@ def test_clear_cache(restore_resolvers: Any) -> None: assert old != c.k -@mark.parametrize( - ("cfg", "expected"), - [ - ({"a": 0, "b": 1}, {"a": 0, "b": 1}), - ({"a": "${y}"}, {"a": -1}), - ({"a": 0, "b": "${x.a}"}, {"a": 0, "b": 0}), - ({"a": 0, "b": "${.a}"}, {"a": 0, "b": 0}), - ({"a": "${..y}"}, {"a": -1}), - ], -) -def test_resolver_output_dict_to_dictconfig( - restore_resolvers: Any, cfg: Dict[str, Any], expected: Dict[str, Any] -) -> None: - OmegaConf.register_new_resolver("dict", lambda: cfg) +@mark.parametrize("readonly", [True, False]) +def test_resolver_output_dict(restore_resolvers: Any, readonly: bool) -> None: + some_dict = {"a": 0, "b": "${y}"} + OmegaConf.register_new_resolver("dict", lambda: some_dict) c = OmegaConf.create({"x": "${dict:}", "y": -1}) - assert isinstance(c.x, DictConfig) - assert c.x == expected - assert dereference_node(c, "x")._get_flag("readonly") + OmegaConf.set_readonly(c, readonly) + assert isinstance(c.x, dict) + assert c.x == some_dict + x_node = dereference_node(c, "x") + assert isinstance(x_node, AnyNode) + assert x_node._get_flag("allow_objects") +@mark.parametrize("readonly", [True, False]) @mark.parametrize( - ("cfg", "expected"), + ("data", "expected_type"), [ - ([0, 1], [0, 1]), - (["${y}"], [-1]), - ([0, "${x.0}"], [0, 0]), - ([0, "${.0}"], [0, 0]), - (["${..y}"], [-1]), + param({"a": 0, "b": "${y}"}, dict, id="dict"), + param(["a", 0, "${y}"], list, id="list"), ], ) -def test_resolver_output_list_to_listconfig( - restore_resolvers: Any, cfg: List[Any], expected: List[Any] +def test_resolver_output_plain_dict_list( + restore_resolvers: Any, readonly: bool, data: Any, expected_type: type ) -> None: - OmegaConf.register_new_resolver("list", lambda: cfg) - c = OmegaConf.create({"x": "${list:}", "y": -1}) - assert isinstance(c.x, ListConfig) - assert c.x == expected - assert dereference_node(c, "x")._get_flag("readonly") + OmegaConf.register_new_resolver("get_data", lambda: data) + c = OmegaConf.create({"x": "${get_data:}", "y": -1}) + OmegaConf.set_readonly(c, readonly) + + assert isinstance(c.x, expected_type) + assert c.x == data + + x_node = dereference_node(c, "x") + assert isinstance(x_node, AnyNode) + assert x_node._get_flag("allow_objects") def test_register_cached_resolver_with_keyword_unsupported() -> None: diff --git a/tests/interpolation/test_interpolation.py b/tests/interpolation/test_interpolation.py index 8ef0a24c3..376257164 100644 --- a/tests/interpolation/test_interpolation.py +++ b/tests/interpolation/test_interpolation.py @@ -26,7 +26,6 @@ # The above comment is a statement to stop DeepCode from raising a warning on # lines that do equality checks of the form # c.k == c.k -from tests.interpolation import dereference_node def test_interpolation_with_missing() -> None: @@ -279,33 +278,19 @@ def test_none_value_in_quoted_string(restore_resolvers: Any) -> None: id="convert_str_to_int", ), param( - MissingList(list=SI("${identity:[a, b, c]}")), + MissingList(list=SI("${oc.create:[a, b, c]}")), "list", ["a", "b", "c"], ListConfig, id="list_str", ), param( - MissingList(list=SI("${identity:[0, 1, 2]}")), - "list", - ["0", "1", "2"], - ListConfig, - id="list_int_to_str", - ), - param( - MissingDict(dict=SI("${identity:{key1: val1, key2: val2}}")), + MissingDict(dict=SI("${oc.create:{key1: val1, key2: val2}}")), "dict", {"key1": "val1", "key2": "val2"}, DictConfig, id="dict_str", ), - param( - MissingDict(dict=SI("${identity:{a: 0, b: 1}}")), - "dict", - {"a": "0", "b": "1"}, - DictConfig, - id="dict_int_to_str", - ), ], ) def test_interpolation_type_validated_ok( @@ -367,24 +352,6 @@ def drop_last(s: str) -> str: ), id="non_optional_node_interpolation", ), - param( - SubscriptedList(list=SI("${identity:[a, b]}")), - "list", - raises( - InterpolationValidationError, - match=re.escape("Value 'a' could not be converted to Integer"), - ), - id="list_type_mismatch", - ), - param( - MissingDict(dict=SI("${identity:{0: b, 1: d}}")), - "dict", - raises( - InterpolationValidationError, - match=re.escape("Key 0 (int) is incompatible with (str)"), - ), - id="dict_key_type_mismatch", - ), ], ) def test_interpolation_type_validated_error( @@ -402,39 +369,52 @@ def test_interpolation_type_validated_error( @mark.parametrize( - ("cfg", "key"), + ("cfg", "key", "expected_value", "expected_node_type"), [ - param({"dict": "${identity:{a: 0, b: 1}}"}, "dict.a", id="dict"), param( - {"dict": "${identity:{a: 0, b: {c: 1}}}"}, - "dict.b.c", - id="dict_nested", + MissingList(list=SI("${oc.create:[0, 1, 2]}")), + "list", + [0, 1, 2], + ListConfig, + id="list_int_to_str", + ), + param( + MissingDict(dict=SI("${oc.create:{a: 0, b: 1}}")), + "dict", + {"a": 0, "b": 1}, + DictConfig, + id="dict_int_to_str", + ), + param( + SubscriptedList(list=SI("${oc.create:[a, b]}")), + "list", + ["a", "b"], + ListConfig, + id="list_type_mismatch", + ), + param( + MissingDict(dict=SI("${oc.create:{0: b, 1: d}}")), + "dict", + {0: "b", 1: "d"}, + DictConfig, + id="dict_key_type_mismatch", ), - param({"list": "${identity:[0, 1]}"}, "list.0", id="list"), - param({"list": "${identity:[0, [1, 2]]}"}, "list.1.1", id="list_nested"), ], ) -def test_interpolation_readonly_resolver_output( - common_resolvers: Any, cfg: Any, key: str -) -> None: - cfg = OmegaConf.create(cfg) - sub_key: Any - parent_key, sub_key = key.rsplit(".", 1) - try: - sub_key = int(sub_key) # convert list index to integer - except ValueError: - pass - parent_node = OmegaConf.select(cfg, parent_key) - assert parent_node._get_flag("readonly") - - -def test_interpolation_readonly_node() -> None: - cfg = OmegaConf.structured(User(name="7", age=II("name"))) - resolved = dereference_node(cfg, "age") - assert resolved == 7 - # The `resolved` node must be read-only because `age` is an integer, so the - # interpolation cannot return directly the `name` node. - assert resolved._get_flag("readonly") +def test_interpolation_type_not_validated( + cfg: Any, + key: str, + expected_value: Any, + expected_node_type: Any, +) -> Any: + cfg = OmegaConf.structured(cfg) + + val = cfg[key] + assert val == expected_value + + node = cfg._get_node(key) + assert isinstance(node, Node) + assert isinstance(node._dereference_node(), expected_node_type) def test_type_validation_error_no_throw() -> None: diff --git a/tests/test_nodes.py b/tests/test_nodes.py index eb2d7af26..b9555a1bb 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -1,4 +1,5 @@ import copy +import functools import re from enum import Enum from typing import Any, Dict, Tuple, Type @@ -596,3 +597,29 @@ def test_dereference_interpolation_to_missing() -> None: assert x_node._maybe_dereference_node() is None with raises(InterpolationToMissingValueError): cfg.x + + +@mark.parametrize( + "flags", + [ + {}, + {"flag": True}, + {"flag": False}, + {"flag1": True, "flag2": False}, + ], +) +@mark.parametrize( + "type_", + [ + AnyNode, + BooleanNode, + functools.partial(EnumNode, enum_type=Color), + FloatNode, + IntegerNode, + StringNode, + ], +) +def test_set_flags_in_init(type_: Any, flags: Dict[str, bool]) -> None: + node = type_(value=None, flags=flags) + for f, v in flags.items(): + assert node._get_flag(f) is v