Skip to content

Commit

Permalink
Fix crash with "interpolation-like" strings from interpolations
Browse files Browse the repository at this point in the history
Fixes omry#666
  • Loading branch information
odelalleau committed May 11, 2021
1 parent 0b99ead commit 1a5bbf9
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 8 deletions.
3 changes: 3 additions & 0 deletions omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,9 @@ def _wrap_interpolation_result(
value=resolved,
key=key,
ref_type=value._metadata.ref_type,
# Since `resolved` was obtained by resolving an interpolation, it cannot
# be itself an interpolation even if may look like one (ex: "${foo}").
can_be_interpolation=False,
)
else:
# Other objects get wrapped into an `AnyNode` with `allow_objects` set
Expand Down
41 changes: 36 additions & 5 deletions omegaconf/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,25 @@ def _value(self) -> Any:
return self._val

def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
return self._set_value_impl(value=value, can_be_interpolation=True, flags=flags)

def _set_value_impl(
self,
value: Any,
can_be_interpolation: bool,
flags: Optional[Dict[str, bool]] = None,
) -> None:
if self._get_flag("readonly"):
raise ReadonlyConfigError("Cannot set value of read-only config node")

if isinstance(value, str) and get_value_kind(
value, strict_interpolation_validation=True
) in (
ValueKind.INTERPOLATION,
ValueKind.MANDATORY_MISSING,
if (
can_be_interpolation
and isinstance(value, str)
and get_value_kind(value, strict_interpolation_validation=True)
in (
ValueKind.INTERPOLATION,
ValueKind.MANDATORY_MISSING,
)
):
self._val = value
else:
Expand Down Expand Up @@ -112,7 +123,9 @@ def __init__(
key: Any = None,
parent: Optional[Container] = None,
flags: Optional[Dict[str, bool]] = None,
can_be_interpolation: bool = True,
):
self.can_be_interpolation = can_be_interpolation
super().__init__(
parent=parent,
value=value,
Expand All @@ -121,6 +134,11 @@ def __init__(
),
)

def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
return self._set_value_impl(
value=value, can_be_interpolation=self.can_be_interpolation, flags=flags
)

def _validate_and_convert_impl(self, value: Any) -> Any:
from ._utils import is_primitive_type

Expand All @@ -140,6 +158,9 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> "AnyNode":
self._deepcopy_impl(res, memo)
return res

def _is_interpolation(self) -> bool:
return self.can_be_interpolation and super()._is_interpolation()


class StringNode(ValueNode):
def __init__(
Expand All @@ -149,7 +170,9 @@ def __init__(
parent: Optional[Container] = None,
is_optional: bool = True,
flags: Optional[Dict[str, bool]] = None,
can_be_interpolation: bool = True,
):
self.can_be_interpolation = can_be_interpolation
super().__init__(
parent=parent,
value=value,
Expand All @@ -162,6 +185,11 @@ def __init__(
),
)

def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
return self._set_value_impl(
value=value, can_be_interpolation=self.can_be_interpolation, flags=flags
)

def _validate_and_convert_impl(self, value: Any) -> str:
from omegaconf import OmegaConf

Expand All @@ -174,6 +202,9 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> "StringNode":
self._deepcopy_impl(res, memo)
return res

def _is_interpolation(self) -> bool:
return self.can_be_interpolation and super()._is_interpolation()


class IntegerNode(ValueNode):
def __init__(
Expand Down
25 changes: 22 additions & 3 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,6 +952,9 @@ def _node_wrap(
value: Any,
key: Any,
ref_type: Any = Any,
# Flag indicating whether the input value may be considered to be an interpolation.
# It is only used when wrapping a string within an `AnyNode` or `StringNode`.
can_be_interpolation: bool = True,
) -> Node:
node: Node
is_dict = is_primitive_dict(value) or is_dict_annotation(type_)
Expand Down Expand Up @@ -993,7 +996,12 @@ def _node_wrap(
element_type=element_type,
)
elif type_ == Any or type_ is None:
node = AnyNode(value=value, key=key, parent=parent)
node = AnyNode(
value=value,
key=key,
parent=parent,
can_be_interpolation=can_be_interpolation,
)
elif issubclass(type_, Enum):
node = EnumNode(
enum_type=type_,
Expand All @@ -1009,10 +1017,21 @@ def _node_wrap(
elif type_ == bool:
node = BooleanNode(value=value, key=key, parent=parent, is_optional=is_optional)
elif type_ == str:
node = StringNode(value=value, key=key, parent=parent, is_optional=is_optional)
node = StringNode(
value=value,
key=key,
parent=parent,
is_optional=is_optional,
can_be_interpolation=can_be_interpolation,
)
else:
if parent is not None and parent._get_flag("allow_objects") is True:
node = AnyNode(value=value, key=key, parent=parent)
node = AnyNode(
value=value,
key=key,
parent=parent,
can_be_interpolation=can_be_interpolation,
)
else:
raise ValidationError(f"Unexpected object type: {type_str(type_)}")
return node
Expand Down
34 changes: 34 additions & 0 deletions tests/interpolation/test_interpolation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import re
from dataclasses import dataclass
from textwrap import dedent
from typing import Any, Tuple

Expand All @@ -22,6 +23,7 @@
from omegaconf.errors import InterpolationResolutionError as IRE
from omegaconf.errors import InterpolationValidationError
from tests import MissingDict, MissingList, StructuredWithMissing, SubscriptedList, User
from tests.interpolation import dereference_node

# file deepcode ignore CopyPasteError:
# The above comment is a statement to stop DeepCode from raising a warning on
Expand Down Expand Up @@ -463,3 +465,35 @@ def test_circular_interpolation(cfg: Any, key: str, expected: Any) -> None:
OmegaConf.select(cfg, key)
else:
assert OmegaConf.select(cfg, key) == expected


@mark.parametrize("node_type", [None, Any, str])
@mark.parametrize(
"value",
[
param(r"\${foo", id="escaped_interpolation"),
param(r"$${y}", id="string_interpolation"),
# This passes to `oc.decode` the string with characters: '\${foo' which
# is then resolved into: ${foo
param(r"${oc.decode:'\'\\\${foo\''}", id="resolver"),
],
)
def test_interpolation_result_is_not_an_interpolation(
node_type: Any, value: str
) -> None:
if node_type is None:
# Non-structured config.
cfg = OmegaConf.create({"x": value, "y": "{foo"})

else:
# Structured config.

@dataclass
class Config:
x: node_type = value # type: ignore
y: str = "{foo"

cfg = OmegaConf.structured(Config)

assert cfg.x == "${foo"
assert not dereference_node(cfg, "x")._is_interpolation()

0 comments on commit 1a5bbf9

Please sign in to comment.