diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index dd6a0c330..5ebe4c003 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -292,9 +292,7 @@ def _s_validate_and_normalize_key(self, key_type: Any, key: Any) -> DictKeyType: return key # type: ignore elif issubclass(key_type, Enum): try: - ret = EnumNode.validate_and_convert_to_enum( - key_type, key, allow_none=False - ) + ret = EnumNode.validate_and_convert_to_enum(key_type, key) assert ret is not None return ret except ValidationError: diff --git a/omegaconf/nodes.py b/omegaconf/nodes.py index ef4806638..ce39e1387 100644 --- a/omegaconf/nodes.py +++ b/omegaconf/nodes.py @@ -40,16 +40,22 @@ def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> Non ): self._val = value else: - if not self._metadata.optional and value is None: - raise ValidationError("Non optional field cannot be assigned None") self._val = self.validate_and_convert(value) def validate_and_convert(self, value: Any) -> Any: """ Validates input and converts to canonical form :param value: input value - :return: converted value ("100" may be converted to 100 for example) + :return: converted value ("100" may be converted to 100 for example) """ + if value is None: + if self._is_optional(): + return None + raise ValidationError("Non optional field cannot be assigned None") + # Subclasses can assume that `value` is not None in `_validate_and_convert_impl()`. + return self._validate_and_convert_impl(value) + + def _validate_and_convert_impl(self, value: Any) -> Any: return value def __str__(self) -> str: @@ -113,17 +119,14 @@ def __init__( value: Any = None, key: Any = None, parent: Optional[Container] = None, - is_optional: bool = True, ): super().__init__( parent=parent, value=value, - metadata=Metadata( - ref_type=Any, object_type=None, key=key, optional=is_optional - ), + metadata=Metadata(ref_type=Any, object_type=None, key=key, optional=True), ) - def validate_and_convert(self, value: Any) -> Any: + def _validate_and_convert_impl(self, value: Any) -> Any: from ._utils import is_primitive_type # allow_objects is internal and not an official API. use at your own risk. @@ -159,12 +162,12 @@ def __init__( ), ) - def validate_and_convert(self, value: Any) -> Optional[str]: + def _validate_and_convert_impl(self, value: Any) -> str: from omegaconf import OmegaConf if OmegaConf.is_config(value) or is_primitive_container(value): raise ValidationError("Cannot convert '$VALUE_TYPE' to string : '$VALUE'") - return str(value) if value is not None else None + return str(value) def __deepcopy__(self, memo: Dict[int, Any]) -> "StringNode": res = StringNode() @@ -188,11 +191,9 @@ def __init__( ), ) - def validate_and_convert(self, value: Any) -> Optional[int]: + def _validate_and_convert_impl(self, value: Any) -> int: try: - if value is None: - val = None - elif type(value) in (str, int): + if type(value) in (str, int): val = int(value) else: raise ValueError() @@ -222,9 +223,7 @@ def __init__( ), ) - def validate_and_convert(self, value: Any) -> Optional[float]: - if value is None: - return None + def _validate_and_convert_impl(self, value: Any) -> float: try: if type(value) in (float, str, int): return float(value) @@ -273,16 +272,14 @@ def __init__( ), ) - def validate_and_convert(self, value: Any) -> Optional[bool]: + def _validate_and_convert_impl(self, value: Any) -> bool: if isinstance(value, bool): return value if isinstance(value, int): return value != 0 - elif value is None: - return None elif isinstance(value, str): try: - return self.validate_and_convert(int(value)) + return self._validate_and_convert_impl(int(value)) except ValueError as e: if value.lower() in ("yes", "y", "on", "true"): return True @@ -335,16 +332,11 @@ def __init__( ), ) - def validate_and_convert(self, value: Any) -> Optional[Enum]: + def _validate_and_convert_impl(self, value: Any) -> Enum: return self.validate_and_convert_to_enum(enum_type=self.enum_type, value=value) @staticmethod - def validate_and_convert_to_enum( - enum_type: Type[Enum], value: Any, allow_none: bool = True - ) -> Optional[Enum]: - if allow_none and value is None: - return None - + def validate_and_convert_to_enum(enum_type: Type[Enum], value: Any) -> Enum: if not isinstance(value, (str, int)) and not isinstance(value, enum_type): raise ValidationError( f"Value $VALUE ($VALUE_TYPE) is not a valid input for {enum_type}" diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index d78e2605e..2d4fea323 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -937,7 +937,7 @@ def _node_wrap( element_type=element_type, ) elif type_ == Any or type_ is None: - node = AnyNode(value=value, key=key, parent=parent, is_optional=is_optional) + node = AnyNode(value=value, key=key, parent=parent) elif issubclass(type_, Enum): node = EnumNode( enum_type=type_, @@ -956,7 +956,7 @@ def _node_wrap( node = StringNode(value=value, key=key, parent=parent, is_optional=is_optional) else: if parent is not None and parent._get_flag("allow_objects") is True: - node = AnyNode(value=value, key=key, parent=parent, is_optional=is_optional) + node = AnyNode(value=value, key=key, parent=parent) else: raise ValidationError(f"Unexpected object type : {type_str(type_)}") return node diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index 8a912b657..d8058d9cd 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -747,3 +747,9 @@ def fail_if_called(x: Any) -> None: x_node = cfg._get_node("x") assert isinstance(x_node, Node) assert x_node._dereference_node(throw_on_resolution_failure=False) is None + + +def test_none_value_in_quoted_string(restore_resolvers: Any) -> None: + OmegaConf.register_new_resolver("test", lambda x: x) + cfg = OmegaConf.create({"x": "${test:'${missing}'}", "missing": None}) + assert cfg.x == "None" diff --git a/tests/test_nodes.py b/tests/test_nodes.py index f71fcd2dd..365f24080 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -1,4 +1,5 @@ import copy +import re from enum import Enum from typing import Any, Dict, Tuple, Type @@ -487,12 +488,7 @@ def test_deepcopy(obj: Any) -> None: (BooleanNode(True), None, False), (BooleanNode(True), False, False), (BooleanNode(False), False, True), - (AnyNode(value=1, is_optional=True), AnyNode(value=1, is_optional=True), True), - ( - AnyNode(value=1, is_optional=True), - AnyNode(value=1, is_optional=False), - True, - ), + (AnyNode(value=1), AnyNode(value=1), True), (EnumNode(enum_type=Enum1), Enum1.BAR, False), (EnumNode(enum_type=Enum1), EnumNode(Enum1), True), (EnumNode(enum_type=Enum1), "nope", False), @@ -573,6 +569,26 @@ def test_dereference_missing() -> None: assert x_node._dereference_node() is x_node +@pytest.mark.parametrize( + "make_func", + [ + StringNode, + IntegerNode, + FloatNode, + BooleanNode, + lambda val, is_optional: EnumNode( + enum_type=Color, value=val, is_optional=is_optional + ), + ], +) +def test_validate_and_convert_none(make_func: Any) -> None: + node = make_func("???", is_optional=False) + with pytest.raises( + ValidationError, match=re.escape("Non optional field cannot be assigned None") + ): + node.validate_and_convert(None) + + def test_dereference_interpolation_to_missing() -> None: cfg = OmegaConf.create({"x": "${y}", "y": "???"}) x_node = cfg._get_node("x")