From fe6b207da0198aa84aee7ac018f5a98de77fab1f Mon Sep 17 00:00:00 2001 From: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Date: Wed, 12 May 2021 15:14:20 -0400 Subject: [PATCH] Fix crash with "interpolation-like" strings from interpolations (#709) Fix crash with "interpolation-like" strings from interpolations This commit introduces a new node type `InterpolationResultNode` that systematically wraps interpolation results that either (a) are not already nodes, or (b) need to be converted. Fixes #666 --- omegaconf/base.py | 46 +---------------- omegaconf/nodes.py | 43 +++++++++++++++- tests/interpolation/test_custom_resolvers.py | 10 ++-- tests/interpolation/test_interpolation.py | 53 +++++++++++++++++--- tests/test_nodes.py | 45 ++++++++++++++++- 5 files changed, 140 insertions(+), 57 deletions(-) diff --git a/omegaconf/base.py b/omegaconf/base.py index 859f80245..0b44138e2 100644 --- a/omegaconf/base.py +++ b/omegaconf/base.py @@ -15,7 +15,6 @@ _is_missing_value, format_and_raise, get_value_kind, - is_primitive_type, split_key, ) from .errors import ( @@ -504,7 +503,7 @@ def _validate_and_convert_interpolation_result( resolved: Any, throw_on_resolution_failure: bool, ) -> Optional["Node"]: - from .nodes import AnyNode, ValueNode + from .nodes import AnyNode, InterpolationResultNode, ValueNode # If the output is not a Node already (e.g., because it is the output of a # custom resolver), then we will need to wrap it within a Node. @@ -533,52 +532,11 @@ def _validate_and_convert_interpolation_result( resolved = conv_value if must_wrap: - return self._wrap_interpolation_result( - parent=parent, - value=value, - key=key, - resolved=resolved, - throw_on_resolution_failure=throw_on_resolution_failure, - ) + return InterpolationResultNode(value=resolved, key=key, parent=parent) else: assert isinstance(resolved, Node) return resolved - def _wrap_interpolation_result( - self, - parent: Optional["Container"], - value: Node, - key: Any, - resolved: Any, - throw_on_resolution_failure: bool, - ) -> Optional["Node"]: - from .basecontainer import BaseContainer - from .nodes import AnyNode - from .omegaconf import _node_wrap - - assert parent is None or isinstance(parent, BaseContainer) - - if is_primitive_type(type(resolved)): - # Primitive types get wrapped using `_node_wrap()`, ensuring value is - # validated and potentially converted. - wrapped = _node_wrap( - type_=value._metadata.ref_type, - parent=parent, - is_optional=value._metadata.optional, - value=resolved, - key=key, - ref_type=value._metadata.ref_type, - ) - else: - # Other objects get wrapped into an `AnyNode` with `allow_objects` set - # to True. - wrapped = AnyNode( - value=resolved, key=key, parent=None, flags={"allow_objects": True} - ) - wrapped._set_parent(parent) - - return wrapped - def _validate_not_dereferencing_to_parent(self, node: Node, target: Node) -> None: parent: Optional[Node] = node while parent is not None: diff --git a/omegaconf/nodes.py b/omegaconf/nodes.py index 400f531f9..d0f5af81e 100644 --- a/omegaconf/nodes.py +++ b/omegaconf/nodes.py @@ -24,7 +24,7 @@ def __init__(self, parent: Optional[Container], value: Any, metadata: Metadata): super().__init__(parent=parent, metadata=metadata) with read_write(self): - self._set_value(value) + self._set_value(value) # lgtm [py/init-calls-subclass] def _value(self) -> Any: return self._val @@ -390,3 +390,44 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> "EnumNode": res = EnumNode(enum_type=self.enum_type) self._deepcopy_impl(res, memo) return res + + +class InterpolationResultNode(ValueNode): + """ + Special node type, used to wrap interpolation results. + """ + + def __init__( + self, + value: Any, + key: Any = None, + parent: Optional[Container] = None, + flags: Optional[Dict[str, bool]] = None, + ): + super().__init__( + parent=parent, + value=value, + metadata=Metadata( + ref_type=Any, object_type=None, key=key, optional=True, flags=flags + ), + ) + # In general we should not try to write into interpolation results. + if flags is None or "readonly" not in flags: + self._set_flag("readonly", True) + + def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None: + if self._get_flag("readonly"): + raise ReadonlyConfigError("Cannot set value of read-only config node") + self._val = self.validate_and_convert(value) + + def _validate_and_convert_impl(self, value: Any) -> Any: + # Interpolation results may be anything. + return value + + def __deepcopy__(self, memo: Dict[int, Any]) -> "InterpolationResultNode": + # Currently there should be no need to deep-copy such nodes. + raise NotImplementedError + + def _is_interpolation(self) -> bool: + # The result of an interpolation cannot be itself an interpolation. + return False diff --git a/tests/interpolation/test_custom_resolvers.py b/tests/interpolation/test_custom_resolvers.py index 15292f555..fddce82bb 100644 --- a/tests/interpolation/test_custom_resolvers.py +++ b/tests/interpolation/test_custom_resolvers.py @@ -5,7 +5,7 @@ from pytest import mark, param, raises, warns from omegaconf import OmegaConf, Resolver -from omegaconf.nodes import AnyNode +from omegaconf.nodes import InterpolationResultNode from tests.interpolation import dereference_node @@ -355,8 +355,8 @@ def test_resolver_output_dict(restore_resolvers: Any, readonly: bool) -> None: assert isinstance(c.x, dict) assert c.x == some_dict x_node = dereference_node(c, "x") - assert isinstance(x_node, AnyNode) - assert x_node._get_flag("allow_objects") + assert isinstance(x_node, InterpolationResultNode) + assert x_node._get_flag("readonly") @mark.parametrize("readonly", [True, False]) @@ -378,8 +378,8 @@ def test_resolver_output_plain_dict_list( assert c.x == data x_node = dereference_node(c, "x") - assert isinstance(x_node, AnyNode) - assert x_node._get_flag("allow_objects") + assert isinstance(x_node, InterpolationResultNode) + assert x_node._get_flag("readonly") def test_register_cached_resolver_with_keyword_unsupported() -> None: diff --git a/tests/interpolation/test_interpolation.py b/tests/interpolation/test_interpolation.py index eb3b5807b..e769fb5c4 100644 --- a/tests/interpolation/test_interpolation.py +++ b/tests/interpolation/test_interpolation.py @@ -8,20 +8,24 @@ from omegaconf import ( II, SI, + AnyNode, Container, DictConfig, IntegerNode, ListConfig, Node, OmegaConf, + StringNode, ValidationError, ) from omegaconf._utils import _ensure_container from omegaconf.errors import InterpolationKeyError from omegaconf.errors import InterpolationResolutionError from omegaconf.errors import InterpolationResolutionError as IRE -from omegaconf.errors import InterpolationValidationError +from omegaconf.errors import InterpolationValidationError, ReadonlyConfigError +from omegaconf.nodes import InterpolationResultNode from tests import MissingDict, MissingList, StructuredWithMissing, SubscriptedList, User +from tests.interpolation import dereference_node # file deepcode ignore CopyPasteError: # The above comment is a statement to stop DeepCode from raising a warning on @@ -257,7 +261,7 @@ def test_none_value_in_quoted_string(restore_resolvers: Any) -> None: User(name="Bond", age=SI("${cast:int,'7'}")), "age", 7, - IntegerNode, + InterpolationResultNode, id="expected_type", ), param( @@ -266,7 +270,7 @@ def test_none_value_in_quoted_string(restore_resolvers: Any) -> None: User(name="Bond", age=SI("${cast:int,${drop_last:${drop_last:7xx}}}")), "age", 7, - IntegerNode, + InterpolationResultNode, id="intermediate_type_mismatch_ok", ), param( @@ -275,20 +279,20 @@ def test_none_value_in_quoted_string(restore_resolvers: Any) -> None: User(name="Bond", age=SI("${cast:str,'7'}")), "age", 7, - IntegerNode, + InterpolationResultNode, id="convert_str_to_int", ), param( MissingList(list=SI("${oc.create:[a, b, c]}")), "list", - ["a", "b", "c"], + ListConfig(["a", "b", "c"]), ListConfig, id="list_str", ), param( MissingDict(dict=SI("${oc.create:{key1: val1, key2: val2}}")), "dict", - {"key1": "val1", "key2": "val2"}, + DictConfig({"key1": "val1", "key2": "val2"}), DictConfig, id="dict_str", ), @@ -310,6 +314,7 @@ def drop_last(s: str) -> str: val = cfg[key] assert val == expected_value + assert type(val) is type(expected_value) node = cfg._get_node(key) assert isinstance(node, Node) @@ -463,3 +468,39 @@ def test_circular_interpolation(cfg: Any, key: str, expected: Any) -> None: OmegaConf.select(cfg, key) else: assert OmegaConf.select(cfg, key) == expected + + +@mark.parametrize( + "node_type", + [ + param(lambda x: x, id="untyped"), + param(AnyNode, id="any"), + param(StringNode, id="str"), + ], +) +@mark.parametrize( + ("value", "expected"), + [ + param(r"\${foo}", "${foo}", id="escaped_interpolation_1"), + param(r"\${foo", "${foo", id="escaped_interpolation_2"), + param(r"$${y1}", "${foo}", id="string_interpolation_1"), + param(r"$${y2}", "${foo", id="string_interpolation_2"), + # This passes to `oc.decode` the string with characters: '\${foo}' which + # is then resolved into: ${foo} + param(r"${oc.decode:'\'\\\${foo}\''}", "${foo}", id="resolver_1"), + param(r"${oc.decode:'\'\\\${foo\''}", "${foo", id="resolver_2"), + ], +) +def test_interpolation_like_result_is_not_an_interpolation( + node_type: Any, value: str, expected: str +) -> None: + cfg = OmegaConf.create({"x": node_type(value), "y1": "{foo}", "y2": "{foo"}) + assert cfg.x == expected + + # Check that the resulting node is not considered to be an interpolation. + resolved_node = dereference_node(cfg, "x") + assert not resolved_node._is_interpolation() + + # Check that the resulting node is read-only. + with raises(ReadonlyConfigError): + resolved_node._set_value("foo") diff --git a/tests/test_nodes.py b/tests/test_nodes.py index b9555a1bb..0c9a7a834 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -24,6 +24,7 @@ UnsupportedValueType, ValidationError, ) +from omegaconf.nodes import InterpolationResultNode from tests import Color, IllegalType, User @@ -499,6 +500,18 @@ def test_deepcopy(obj: Any) -> None: True, ), (EnumNode(enum_type=Enum1, value=Enum1.BAR), Enum1.BAR, True), + (InterpolationResultNode("foo"), "foo", True), + (InterpolationResultNode("${foo}"), "${foo}", True), + (InterpolationResultNode("${foo"), "${foo", True), + (InterpolationResultNode(None), None, True), + (InterpolationResultNode(1), 1, True), + (InterpolationResultNode(1.0), 1.0, True), + (InterpolationResultNode(True), True, True), + (InterpolationResultNode(Color.RED), Color.RED, True), + (InterpolationResultNode({"a": 0, "b": 1}), {"a": 0, "b": 1}, True), + (InterpolationResultNode([0, None, True]), [0, None, True], True), + (InterpolationResultNode("foo"), 100, False), + (InterpolationResultNode(100), "foo", False), ], ) def test_eq(node: ValueNode, value: Any, expected: Any) -> None: @@ -506,7 +519,10 @@ def test_eq(node: ValueNode, value: Any, expected: Any) -> None: assert (node != value) != expected assert (value == node) == expected assert (value != node) != expected - assert (node.__hash__() == value.__hash__()) == expected + + # Check hash except for unhashable types (dict/list). + if not isinstance(value, (dict, list)): + assert (node.__hash__() == value.__hash__()) == expected @mark.parametrize("value", [1, 3.14, True, None, Enum1.FOO]) @@ -616,6 +632,7 @@ def test_dereference_interpolation_to_missing() -> None: functools.partial(EnumNode, enum_type=Color), FloatNode, IntegerNode, + InterpolationResultNode, StringNode, ], ) @@ -623,3 +640,29 @@ def test_set_flags_in_init(type_: Any, flags: Dict[str, bool]) -> None: node = type_(value=None, flags=flags) for f, v in flags.items(): assert node._get_flag(f) is v + + +@mark.parametrize( + "flags", + [ + None, + {"flag": True}, + {"flag": False}, + {"readonly": True}, + {"readonly": False}, + {"flag1": True, "flag2": False, "readonly": False}, + {"flag1": False, "flag2": True, "readonly": True}, + ], +) +def test_interpolation_result_readonly(flags: Any) -> None: + readonly = None if flags is None else flags.get("readonly") + expected = [] if flags is None else list(flags.items()) + node = InterpolationResultNode("foo", flags=flags) + + # Check that flags are set to their desired value. + for k, v in expected: + assert node._get_node_flag(k) is v + + # If no value was provided for the "readonly" flag, it should be set. + if readonly is None: + assert node._get_node_flag("readonly")