Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved handling of None values in node validation #592

Merged
merged 8 commits into from
Mar 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions omegaconf/dictconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,7 @@ def _s_validate_and_normalize_key(self, key_type: Any, key: Any) -> DictKeyType:
return key # type: ignore
elif issubclass(key_type, Enum):
try:
ret = EnumNode.validate_and_convert_to_enum(
key_type, key, allow_none=False
)
ret = EnumNode.validate_and_convert_to_enum(key_type, key)
assert ret is not None
return ret
except ValidationError:
Expand Down
48 changes: 20 additions & 28 deletions omegaconf/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,22 @@ def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> Non
):
self._val = value
else:
if not self._metadata.optional and value is None:
raise ValidationError("Non optional field cannot be assigned None")
self._val = self.validate_and_convert(value)

def validate_and_convert(self, value: Any) -> Any:
"""
Validates input and converts to canonical form
:param value: input value
:return: converted value ("100" may be converted to 100 for example)
:return: converted value ("100" may be converted to 100 for example)
"""
if value is None:
if self._is_optional():
return None
raise ValidationError("Non optional field cannot be assigned None")
# Subclasses can assume that `value` is not None in `_validate_and_convert_impl()`.
return self._validate_and_convert_impl(value)

def _validate_and_convert_impl(self, value: Any) -> Any:
return value

def __str__(self) -> str:
Expand Down Expand Up @@ -113,17 +119,14 @@ def __init__(
value: Any = None,
key: Any = None,
parent: Optional[Container] = None,
is_optional: bool = True,
):
super().__init__(
parent=parent,
value=value,
metadata=Metadata(
ref_type=Any, object_type=None, key=key, optional=is_optional
),
metadata=Metadata(ref_type=Any, object_type=None, key=key, optional=True),
)

def validate_and_convert(self, value: Any) -> Any:
def _validate_and_convert_impl(self, value: Any) -> Any:
from ._utils import is_primitive_type

# allow_objects is internal and not an official API. use at your own risk.
Expand Down Expand Up @@ -159,12 +162,12 @@ def __init__(
),
)

def validate_and_convert(self, value: Any) -> Optional[str]:
def _validate_and_convert_impl(self, value: Any) -> str:
from omegaconf import OmegaConf

if OmegaConf.is_config(value) or is_primitive_container(value):
raise ValidationError("Cannot convert '$VALUE_TYPE' to string : '$VALUE'")
return str(value) if value is not None else None
return str(value)

def __deepcopy__(self, memo: Dict[int, Any]) -> "StringNode":
res = StringNode()
Expand All @@ -188,11 +191,9 @@ def __init__(
),
)

def validate_and_convert(self, value: Any) -> Optional[int]:
def _validate_and_convert_impl(self, value: Any) -> int:
try:
if value is None:
val = None
elif type(value) in (str, int):
if type(value) in (str, int):
val = int(value)
else:
raise ValueError()
Expand Down Expand Up @@ -222,9 +223,7 @@ def __init__(
),
)

def validate_and_convert(self, value: Any) -> Optional[float]:
if value is None:
return None
def _validate_and_convert_impl(self, value: Any) -> float:
try:
if type(value) in (float, str, int):
return float(value)
Expand Down Expand Up @@ -273,16 +272,14 @@ def __init__(
),
)

def validate_and_convert(self, value: Any) -> Optional[bool]:
def _validate_and_convert_impl(self, value: Any) -> bool:
if isinstance(value, bool):
return value
if isinstance(value, int):
return value != 0
elif value is None:
return None
elif isinstance(value, str):
try:
return self.validate_and_convert(int(value))
return self._validate_and_convert_impl(int(value))
except ValueError as e:
if value.lower() in ("yes", "y", "on", "true"):
return True
Expand Down Expand Up @@ -335,16 +332,11 @@ def __init__(
),
)

def validate_and_convert(self, value: Any) -> Optional[Enum]:
def _validate_and_convert_impl(self, value: Any) -> Enum:
return self.validate_and_convert_to_enum(enum_type=self.enum_type, value=value)

@staticmethod
def validate_and_convert_to_enum(
enum_type: Type[Enum], value: Any, allow_none: bool = True
) -> Optional[Enum]:
if allow_none and value is None:
return None

def validate_and_convert_to_enum(enum_type: Type[Enum], value: Any) -> Enum:
if not isinstance(value, (str, int)) and not isinstance(value, enum_type):
raise ValidationError(
f"Value $VALUE ($VALUE_TYPE) is not a valid input for {enum_type}"
Expand Down
4 changes: 2 additions & 2 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,7 +937,7 @@ def _node_wrap(
element_type=element_type,
)
elif type_ == Any or type_ is None:
node = AnyNode(value=value, key=key, parent=parent, is_optional=is_optional)
node = AnyNode(value=value, key=key, parent=parent)
elif issubclass(type_, Enum):
node = EnumNode(
enum_type=type_,
Expand All @@ -956,7 +956,7 @@ def _node_wrap(
node = StringNode(value=value, key=key, parent=parent, is_optional=is_optional)
else:
if parent is not None and parent._get_flag("allow_objects") is True:
node = AnyNode(value=value, key=key, parent=parent, is_optional=is_optional)
node = AnyNode(value=value, key=key, parent=parent)
else:
raise ValidationError(f"Unexpected object type : {type_str(type_)}")
return node
Expand Down
6 changes: 6 additions & 0 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,3 +747,9 @@ def fail_if_called(x: Any) -> None:
x_node = cfg._get_node("x")
assert isinstance(x_node, Node)
assert x_node._dereference_node(throw_on_resolution_failure=False) is None


def test_none_value_in_quoted_string(restore_resolvers: Any) -> None:
OmegaConf.register_new_resolver("test", lambda x: x)
cfg = OmegaConf.create({"x": "${test:'${missing}'}", "missing": None})
assert cfg.x == "None"
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay, so quoting an interpolation is equivalent to casting the resolved value to a string?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. To clarify, there are two situations:

  1. The interpolation is exactly surrounded by quotes ("${...}") => we resolve it normally but cast the result to string.
  2. There are other characters around the interpolation ("hello ${...} and hi") => this is strictly equivalent to regular string interpolations

I think this is the most intuitive behavior.

28 changes: 22 additions & 6 deletions tests/test_nodes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import re
from enum import Enum
from typing import Any, Dict, Tuple, Type

Expand Down Expand Up @@ -487,12 +488,7 @@ def test_deepcopy(obj: Any) -> None:
(BooleanNode(True), None, False),
(BooleanNode(True), False, False),
(BooleanNode(False), False, True),
(AnyNode(value=1, is_optional=True), AnyNode(value=1, is_optional=True), True),
(
AnyNode(value=1, is_optional=True),
AnyNode(value=1, is_optional=False),
True,
),
(AnyNode(value=1), AnyNode(value=1), True),
(EnumNode(enum_type=Enum1), Enum1.BAR, False),
(EnumNode(enum_type=Enum1), EnumNode(Enum1), True),
(EnumNode(enum_type=Enum1), "nope", False),
Expand Down Expand Up @@ -573,6 +569,26 @@ def test_dereference_missing() -> None:
assert x_node._dereference_node() is x_node


@pytest.mark.parametrize(
"make_func",
[
StringNode,
IntegerNode,
FloatNode,
BooleanNode,
lambda val, is_optional: EnumNode(
enum_type=Color, value=val, is_optional=is_optional
),
],
)
def test_validate_and_convert_none(make_func: Any) -> None:
node = make_func("???", is_optional=False)
with pytest.raises(
ValidationError, match=re.escape("Non optional field cannot be assigned None")
):
node.validate_and_convert(None)


def test_dereference_interpolation_to_missing() -> None:
cfg = OmegaConf.create({"x": "${y}", "y": "???"})
x_node = cfg._get_node("x")
Expand Down