From a3a2a0e7cc057f9f1e85e44283b88d318fd8b8d4 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Mon, 22 Feb 2021 19:31:58 -0600 Subject: [PATCH] hacky attempt to fix #435 --- omegaconf/_utils.py | 44 ++++++++++++------- .../structured_conf/test_structured_basic.py | 8 ++++ tests/test_errors.py | 27 +++++++----- 3 files changed, 52 insertions(+), 27 deletions(-) diff --git a/omegaconf/_utils.py b/omegaconf/_utils.py index 05b002d5e..e4b8f510a 100644 --- a/omegaconf/_utils.py +++ b/omegaconf/_utils.py @@ -15,6 +15,7 @@ ConfigTypeError, ConfigValueError, OmegaConfBaseException, + ValidationError, ) from .grammar_parser import parse @@ -195,13 +196,15 @@ def _resolve_forward(type_: Type[Any], module: str) -> Type[Any]: def get_attr_data(obj: Any, allow_objects: Optional[bool] = None) -> Dict[str, Any]: from omegaconf.omegaconf import OmegaConf, _maybe_wrap + obj_type = get_type_of(obj) + flags = {"allow_objects": allow_objects} if allow_objects is not None else {} dummy_parent = OmegaConf.create(flags=flags) + dummy_parent._metadata.object_type = obj_type from omegaconf import MISSING d = {} is_type = isinstance(obj, type) - obj_type = obj if is_type else type(obj) for name, attrib in attr.fields_dict(obj_type).items(): is_optional, type_ = _resolve_optional(attrib.type) type_ = _resolve_forward(type_, obj.__module__) @@ -217,13 +220,16 @@ def get_attr_data(obj: Any, allow_objects: Optional[bool] = None) -> Dict[str, A ) format_and_raise(node=None, key=None, value=value, cause=e, msg=str(e)) - d[name] = _maybe_wrap( - ref_type=type_, - is_optional=is_optional, - key=name, - value=value, - parent=dummy_parent, - ) + try: + d[name] = _maybe_wrap( + ref_type=type_, + is_optional=is_optional, + key=name, + value=value, + parent=dummy_parent, + ) + except ValidationError as ex: + dummy_parent._format_and_raise(key=name, value=value, cause=ex) d[name]._set_parent(None) return d @@ -233,10 +239,13 @@ def get_dataclass_data( ) -> Dict[str, Any]: from omegaconf.omegaconf import MISSING, OmegaConf, _maybe_wrap + obj_type = get_type_of(obj) + flags = {"allow_objects": allow_objects} if allow_objects is not None else {} dummy_parent = OmegaConf.create({}, flags=flags) + dummy_parent._metadata.object_type = obj_type d = {} - resolved_hints = get_type_hints(get_type_of(obj)) + resolved_hints = get_type_hints(obj_type) for field in dataclasses.fields(obj): name = field.name is_optional, type_ = _resolve_optional(resolved_hints[field.name]) @@ -257,13 +266,16 @@ def get_dataclass_data( f"Union types are not supported:\n{name}: {type_str(type_)}" ) format_and_raise(node=None, key=None, value=value, cause=e, msg=str(e)) - d[name] = _maybe_wrap( - ref_type=type_, - is_optional=is_optional, - key=name, - value=value, - parent=dummy_parent, - ) + try: + d[name] = _maybe_wrap( + ref_type=type_, + is_optional=is_optional, + key=name, + value=value, + parent=dummy_parent, + ) + except ValidationError as ex: + dummy_parent._format_and_raise(key=name, value=value, cause=ex) d[name]._set_parent(None) return d diff --git a/tests/structured_conf/test_structured_basic.py b/tests/structured_conf/test_structured_basic.py index 0fef79a90..9f61dd475 100644 --- a/tests/structured_conf/test_structured_basic.py +++ b/tests/structured_conf/test_structured_basic.py @@ -48,6 +48,14 @@ def test_error_on_non_structured_nested_config_class( assert list(ret.keys()) == ["bar"] assert ret.bar == module.NotStructuredConfig() + def test_error_on_creation_with_bad_value_type(self, class_type: str) -> None: + module: Any = import_module(class_type) + with pytest.raises( + ValidationError, + match=re.escape("Value 'seven' could not be converted to Integer"), + ): + OmegaConf.structured(module.User(age="seven")) + def test_assignment_of_subclass(self, class_type: str) -> None: module: Any = import_module(class_type) cfg = OmegaConf.create({"plugin": module.Plugin}) diff --git a/tests/test_errors.py b/tests/test_errors.py index 2d0c4d9bb..6c9e3850d 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -474,30 +474,31 @@ def finalize(self, cfg: Any) -> None: pytest.param( Expected( create=lambda: None, - op=lambda cfg: OmegaConf.structured(NotOptionalInt), + op=lambda _: OmegaConf.structured(NotOptionalInt), exception_type=ValidationError, msg="Non optional field cannot be assigned None", - object_type_str=None, - ref_type_str=None, + key="foo", + object_type=NotOptionalInt, + parent_node=lambda _: {}, # dummy parent ), id="dict:create_none_optional_with_none", ), pytest.param( Expected( create=lambda: None, - op=lambda cfg: OmegaConf.structured(NotOptionalInt), + op=lambda _: OmegaConf.structured(NotOptionalInt), exception_type=ValidationError, - object_type=None, + object_type=NotOptionalInt, msg="Non optional field cannot be assigned None", - object_type_str="NotOptionalInt", - ref_type_str=None, + key="foo", + parent_node=lambda _: {}, # dummy parent ), id="dict:create:not_optional_int_field_with_none", ), pytest.param( Expected( create=lambda: None, - op=lambda cfg: OmegaConf.structured(NotOptionalA), + op=lambda _: OmegaConf.structured(NotOptionalA), exception_type=ValidationError, object_type=None, key=None, @@ -511,32 +512,35 @@ def finalize(self, cfg: Any) -> None: pytest.param( Expected( create=lambda: None, - op=lambda cfg: OmegaConf.structured(IllegalType), + op=lambda _: OmegaConf.structured(IllegalType), exception_type=ValidationError, msg="Input class 'IllegalType' is not a structured config. did you forget to decorate it as a dataclass?", object_type_str=None, ref_type_str=None, + parent_node=lambda _: None, ), id="dict_create_from_illegal_type", ), pytest.param( Expected( create=lambda: None, - op=lambda cfg: OmegaConf.structured(IllegalType()), + op=lambda _: OmegaConf.structured(IllegalType()), exception_type=ValidationError, msg="Object of unsupported type: 'IllegalType'", object_type_str=None, ref_type_str=None, + parent_node=lambda _: None, ), id="structured:create_from_unsupported_object", ), pytest.param( Expected( create=lambda: None, - op=lambda cfg: OmegaConf.structured(UnionError), + op=lambda _: OmegaConf.structured(UnionError), exception_type=ValueError, msg="Union types are not supported:\nx: Union[int, str]", num_lines=3, + parent_node=lambda _: None, ), id="structured:create_with_union_error", ), @@ -549,6 +553,7 @@ def finalize(self, cfg: Any) -> None: msg="Invalid type assigned : int is not a subclass of ConcretePlugin. value: 1", low_level=True, ref_type=Optional[ConcretePlugin], + parent_node=lambda _: {}, # dummy parent ), id="dict:set_value:reftype_mismatch", ),