Skip to content

Commit

Permalink
hacky attempt to fix omry#435
Browse files Browse the repository at this point in the history
  • Loading branch information
Jasha10 committed Feb 23, 2021
1 parent e2c60a2 commit a3a2a0e
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 27 deletions.
44 changes: 28 additions & 16 deletions omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ConfigTypeError,
ConfigValueError,
OmegaConfBaseException,
ValidationError,
)
from .grammar_parser import parse

Expand Down Expand Up @@ -195,13 +196,15 @@ def _resolve_forward(type_: Type[Any], module: str) -> Type[Any]:
def get_attr_data(obj: Any, allow_objects: Optional[bool] = None) -> Dict[str, Any]:
from omegaconf.omegaconf import OmegaConf, _maybe_wrap

obj_type = get_type_of(obj)

flags = {"allow_objects": allow_objects} if allow_objects is not None else {}
dummy_parent = OmegaConf.create(flags=flags)
dummy_parent._metadata.object_type = obj_type
from omegaconf import MISSING

d = {}
is_type = isinstance(obj, type)
obj_type = obj if is_type else type(obj)
for name, attrib in attr.fields_dict(obj_type).items():
is_optional, type_ = _resolve_optional(attrib.type)
type_ = _resolve_forward(type_, obj.__module__)
Expand All @@ -217,13 +220,16 @@ def get_attr_data(obj: Any, allow_objects: Optional[bool] = None) -> Dict[str, A
)
format_and_raise(node=None, key=None, value=value, cause=e, msg=str(e))

d[name] = _maybe_wrap(
ref_type=type_,
is_optional=is_optional,
key=name,
value=value,
parent=dummy_parent,
)
try:
d[name] = _maybe_wrap(
ref_type=type_,
is_optional=is_optional,
key=name,
value=value,
parent=dummy_parent,
)
except ValidationError as ex:
dummy_parent._format_and_raise(key=name, value=value, cause=ex)
d[name]._set_parent(None)
return d

Expand All @@ -233,10 +239,13 @@ def get_dataclass_data(
) -> Dict[str, Any]:
from omegaconf.omegaconf import MISSING, OmegaConf, _maybe_wrap

obj_type = get_type_of(obj)

flags = {"allow_objects": allow_objects} if allow_objects is not None else {}
dummy_parent = OmegaConf.create({}, flags=flags)
dummy_parent._metadata.object_type = obj_type
d = {}
resolved_hints = get_type_hints(get_type_of(obj))
resolved_hints = get_type_hints(obj_type)
for field in dataclasses.fields(obj):
name = field.name
is_optional, type_ = _resolve_optional(resolved_hints[field.name])
Expand All @@ -257,13 +266,16 @@ def get_dataclass_data(
f"Union types are not supported:\n{name}: {type_str(type_)}"
)
format_and_raise(node=None, key=None, value=value, cause=e, msg=str(e))
d[name] = _maybe_wrap(
ref_type=type_,
is_optional=is_optional,
key=name,
value=value,
parent=dummy_parent,
)
try:
d[name] = _maybe_wrap(
ref_type=type_,
is_optional=is_optional,
key=name,
value=value,
parent=dummy_parent,
)
except ValidationError as ex:
dummy_parent._format_and_raise(key=name, value=value, cause=ex)
d[name]._set_parent(None)
return d

Expand Down
8 changes: 8 additions & 0 deletions tests/structured_conf/test_structured_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ def test_error_on_non_structured_nested_config_class(
assert list(ret.keys()) == ["bar"]
assert ret.bar == module.NotStructuredConfig()

def test_error_on_creation_with_bad_value_type(self, class_type: str) -> None:
module: Any = import_module(class_type)
with pytest.raises(
ValidationError,
match=re.escape("Value 'seven' could not be converted to Integer"),
):
OmegaConf.structured(module.User(age="seven"))

def test_assignment_of_subclass(self, class_type: str) -> None:
module: Any = import_module(class_type)
cfg = OmegaConf.create({"plugin": module.Plugin})
Expand Down
27 changes: 16 additions & 11 deletions tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,30 +474,31 @@ def finalize(self, cfg: Any) -> None:
pytest.param(
Expected(
create=lambda: None,
op=lambda cfg: OmegaConf.structured(NotOptionalInt),
op=lambda _: OmegaConf.structured(NotOptionalInt),
exception_type=ValidationError,
msg="Non optional field cannot be assigned None",
object_type_str=None,
ref_type_str=None,
key="foo",
object_type=NotOptionalInt,
parent_node=lambda _: {}, # dummy parent
),
id="dict:create_none_optional_with_none",
),
pytest.param(
Expected(
create=lambda: None,
op=lambda cfg: OmegaConf.structured(NotOptionalInt),
op=lambda _: OmegaConf.structured(NotOptionalInt),
exception_type=ValidationError,
object_type=None,
object_type=NotOptionalInt,
msg="Non optional field cannot be assigned None",
object_type_str="NotOptionalInt",
ref_type_str=None,
key="foo",
parent_node=lambda _: {}, # dummy parent
),
id="dict:create:not_optional_int_field_with_none",
),
pytest.param(
Expected(
create=lambda: None,
op=lambda cfg: OmegaConf.structured(NotOptionalA),
op=lambda _: OmegaConf.structured(NotOptionalA),
exception_type=ValidationError,
object_type=None,
key=None,
Expand All @@ -511,32 +512,35 @@ def finalize(self, cfg: Any) -> None:
pytest.param(
Expected(
create=lambda: None,
op=lambda cfg: OmegaConf.structured(IllegalType),
op=lambda _: OmegaConf.structured(IllegalType),
exception_type=ValidationError,
msg="Input class 'IllegalType' is not a structured config. did you forget to decorate it as a dataclass?",
object_type_str=None,
ref_type_str=None,
parent_node=lambda _: None,
),
id="dict_create_from_illegal_type",
),
pytest.param(
Expected(
create=lambda: None,
op=lambda cfg: OmegaConf.structured(IllegalType()),
op=lambda _: OmegaConf.structured(IllegalType()),
exception_type=ValidationError,
msg="Object of unsupported type: 'IllegalType'",
object_type_str=None,
ref_type_str=None,
parent_node=lambda _: None,
),
id="structured:create_from_unsupported_object",
),
pytest.param(
Expected(
create=lambda: None,
op=lambda cfg: OmegaConf.structured(UnionError),
op=lambda _: OmegaConf.structured(UnionError),
exception_type=ValueError,
msg="Union types are not supported:\nx: Union[int, str]",
num_lines=3,
parent_node=lambda _: None,
),
id="structured:create_with_union_error",
),
Expand All @@ -549,6 +553,7 @@ def finalize(self, cfg: Any) -> None:
msg="Invalid type assigned : int is not a subclass of ConcretePlugin. value: 1",
low_level=True,
ref_type=Optional[ConcretePlugin],
parent_node=lambda _: {}, # dummy parent
),
id="dict:set_value:reftype_mismatch",
),
Expand Down

0 comments on commit a3a2a0e

Please sign in to comment.