From 1a5bbf9b42af82c7722174b82f17aceca6a1614f Mon Sep 17 00:00:00 2001 From: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Date: Fri, 7 May 2021 09:22:30 -0400 Subject: [PATCH] Fix crash with "interpolation-like" strings from interpolations Fixes #666 --- omegaconf/base.py | 3 ++ omegaconf/nodes.py | 41 ++++++++++++++++++++--- omegaconf/omegaconf.py | 25 ++++++++++++-- tests/interpolation/test_interpolation.py | 34 +++++++++++++++++++ 4 files changed, 95 insertions(+), 8 deletions(-) diff --git a/omegaconf/base.py b/omegaconf/base.py index 859f80245..fbc4ce86d 100644 --- a/omegaconf/base.py +++ b/omegaconf/base.py @@ -568,6 +568,9 @@ def _wrap_interpolation_result( value=resolved, key=key, ref_type=value._metadata.ref_type, + # Since `resolved` was obtained by resolving an interpolation, it cannot + # be itself an interpolation even if may look like one (ex: "${foo}"). + can_be_interpolation=False, ) else: # Other objects get wrapped into an `AnyNode` with `allow_objects` set diff --git a/omegaconf/nodes.py b/omegaconf/nodes.py index 400f531f9..ab63655c2 100644 --- a/omegaconf/nodes.py +++ b/omegaconf/nodes.py @@ -30,14 +30,25 @@ def _value(self) -> Any: return self._val def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None: + return self._set_value_impl(value=value, can_be_interpolation=True, flags=flags) + + def _set_value_impl( + self, + value: Any, + can_be_interpolation: bool, + flags: Optional[Dict[str, bool]] = None, + ) -> None: if self._get_flag("readonly"): raise ReadonlyConfigError("Cannot set value of read-only config node") - if isinstance(value, str) and get_value_kind( - value, strict_interpolation_validation=True - ) in ( - ValueKind.INTERPOLATION, - ValueKind.MANDATORY_MISSING, + if ( + can_be_interpolation + and isinstance(value, str) + and get_value_kind(value, strict_interpolation_validation=True) + in ( + ValueKind.INTERPOLATION, + ValueKind.MANDATORY_MISSING, + ) ): self._val = value else: @@ -112,7 +123,9 @@ def __init__( key: Any = None, parent: Optional[Container] = None, flags: Optional[Dict[str, bool]] = None, + can_be_interpolation: bool = True, ): + self.can_be_interpolation = can_be_interpolation super().__init__( parent=parent, value=value, @@ -121,6 +134,11 @@ def __init__( ), ) + def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None: + return self._set_value_impl( + value=value, can_be_interpolation=self.can_be_interpolation, flags=flags + ) + def _validate_and_convert_impl(self, value: Any) -> Any: from ._utils import is_primitive_type @@ -140,6 +158,9 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> "AnyNode": self._deepcopy_impl(res, memo) return res + def _is_interpolation(self) -> bool: + return self.can_be_interpolation and super()._is_interpolation() + class StringNode(ValueNode): def __init__( @@ -149,7 +170,9 @@ def __init__( parent: Optional[Container] = None, is_optional: bool = True, flags: Optional[Dict[str, bool]] = None, + can_be_interpolation: bool = True, ): + self.can_be_interpolation = can_be_interpolation super().__init__( parent=parent, value=value, @@ -162,6 +185,11 @@ def __init__( ), ) + def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None: + return self._set_value_impl( + value=value, can_be_interpolation=self.can_be_interpolation, flags=flags + ) + def _validate_and_convert_impl(self, value: Any) -> str: from omegaconf import OmegaConf @@ -174,6 +202,9 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> "StringNode": self._deepcopy_impl(res, memo) return res + def _is_interpolation(self) -> bool: + return self.can_be_interpolation and super()._is_interpolation() + class IntegerNode(ValueNode): def __init__( diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index 27fc9fd5e..949ca93c8 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -952,6 +952,9 @@ def _node_wrap( value: Any, key: Any, ref_type: Any = Any, + # Flag indicating whether the input value may be considered to be an interpolation. + # It is only used when wrapping a string within an `AnyNode` or `StringNode`. + can_be_interpolation: bool = True, ) -> Node: node: Node is_dict = is_primitive_dict(value) or is_dict_annotation(type_) @@ -993,7 +996,12 @@ def _node_wrap( element_type=element_type, ) elif type_ == Any or type_ is None: - node = AnyNode(value=value, key=key, parent=parent) + node = AnyNode( + value=value, + key=key, + parent=parent, + can_be_interpolation=can_be_interpolation, + ) elif issubclass(type_, Enum): node = EnumNode( enum_type=type_, @@ -1009,10 +1017,21 @@ def _node_wrap( elif type_ == bool: node = BooleanNode(value=value, key=key, parent=parent, is_optional=is_optional) elif type_ == str: - node = StringNode(value=value, key=key, parent=parent, is_optional=is_optional) + node = StringNode( + value=value, + key=key, + parent=parent, + is_optional=is_optional, + can_be_interpolation=can_be_interpolation, + ) else: if parent is not None and parent._get_flag("allow_objects") is True: - node = AnyNode(value=value, key=key, parent=parent) + node = AnyNode( + value=value, + key=key, + parent=parent, + can_be_interpolation=can_be_interpolation, + ) else: raise ValidationError(f"Unexpected object type: {type_str(type_)}") return node diff --git a/tests/interpolation/test_interpolation.py b/tests/interpolation/test_interpolation.py index eb3b5807b..8d983cef8 100644 --- a/tests/interpolation/test_interpolation.py +++ b/tests/interpolation/test_interpolation.py @@ -1,5 +1,6 @@ import copy import re +from dataclasses import dataclass from textwrap import dedent from typing import Any, Tuple @@ -22,6 +23,7 @@ from omegaconf.errors import InterpolationResolutionError as IRE from omegaconf.errors import InterpolationValidationError from tests import MissingDict, MissingList, StructuredWithMissing, SubscriptedList, User +from tests.interpolation import dereference_node # file deepcode ignore CopyPasteError: # The above comment is a statement to stop DeepCode from raising a warning on @@ -463,3 +465,35 @@ def test_circular_interpolation(cfg: Any, key: str, expected: Any) -> None: OmegaConf.select(cfg, key) else: assert OmegaConf.select(cfg, key) == expected + + +@mark.parametrize("node_type", [None, Any, str]) +@mark.parametrize( + "value", + [ + param(r"\${foo", id="escaped_interpolation"), + param(r"$${y}", id="string_interpolation"), + # This passes to `oc.decode` the string with characters: '\${foo' which + # is then resolved into: ${foo + param(r"${oc.decode:'\'\\\${foo\''}", id="resolver"), + ], +) +def test_interpolation_result_is_not_an_interpolation( + node_type: Any, value: str +) -> None: + if node_type is None: + # Non-structured config. + cfg = OmegaConf.create({"x": value, "y": "{foo"}) + + else: + # Structured config. + + @dataclass + class Config: + x: node_type = value # type: ignore + y: str = "{foo" + + cfg = OmegaConf.structured(Config) + + assert cfg.x == "${foo" + assert not dereference_node(cfg, "x")._is_interpolation()