Skip to content

Commit

Permalink
Fix bug assigning structured classes of different element_type in dic…
Browse files Browse the repository at this point in the history
…tconfigs #386 (#395)
  • Loading branch information
pereman2 authored Oct 2, 2020
1 parent 5e9de21 commit ea210fb
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 17 deletions.
1 change: 1 addition & 0 deletions news/386.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug that allowed instances of Structured Configs to be assigned to DictConfig with different element type.
6 changes: 5 additions & 1 deletion omegaconf/basecontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,11 @@ def wrap(key: Any, val: Any) -> Node:
target = self._get_node(key)
if target is None:
if is_structured_config(val):
ref_type = OmegaConf.get_type(val)
element_type = self._metadata.element_type
if element_type is Any:
ref_type = OmegaConf.get_type(val)
else:
ref_type = element_type
else:
is_optional = target._is_optional()
ref_type = target._metadata.ref_type
Expand Down
19 changes: 19 additions & 0 deletions omegaconf/dictconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,10 @@ def _validate_set_merge_impl(self, key: Any, value: Any, is_assign: bool) -> Non
if value == "???":
return

if is_assign and isinstance(value, ValueNode) and self._has_element_type():
self._check_assign_value_node(key, value)
return

target: Optional[Node]
if key is None:
target = self
Expand Down Expand Up @@ -228,6 +232,21 @@ def is_typed(c: Any) -> bool:
)
raise ValidationError(msg)

def _check_assign_value_node(self, key: Any, value: Any) -> None:
from omegaconf import OmegaConf

element_type = self._metadata.element_type
value_type = OmegaConf.get_type(value)
if value_type is not Any and not issubclass(value_type, element_type): # type: ignore
msg = (
f"Invalid type assigned : {type_str(value_type)} is not a "
f"subclass of {type_str(element_type)}. value: {value}"
)
raise ValidationError(msg)

def _has_element_type(self) -> bool:
return self._metadata.element_type is not Any

def _validate_and_normalize_key(self, key: Any) -> Union[str, Enum]:
return self._s_validate_and_normalize_key(self._metadata.key_type, key)

Expand Down
2 changes: 1 addition & 1 deletion tests/structured_conf/data/attr_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ class Str2StrWithField(Dict[str, str]):

@attr.s(auto_attribs=True)
class Str2IntWithStrField(Dict[str, int]):
foo: str = "bar"
foo: int = 1

class Error:
@attr.s(auto_attribs=True)
Expand Down
2 changes: 1 addition & 1 deletion tests/structured_conf/data/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ class Str2StrWithField(Dict[str, str]):

@dataclass
class Str2IntWithStrField(Dict[str, int]):
foo: str = "bar"
foo: int = 1

class Error:
@dataclass
Expand Down
32 changes: 18 additions & 14 deletions tests/structured_conf/test_structured_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
_utils,
)
from omegaconf.errors import ConfigKeyError
from tests import Color
from tests import Color, User


class EnumConfigAssignments:
Expand Down Expand Up @@ -734,7 +734,17 @@ def test_assign_wrong_type_to_list(self, class_type: str, value: Any) -> None:
cfg.tuple = value

@pytest.mark.parametrize( # type: ignore
"value", [1, True, "str", 3.1415, ["foo", True, 1.2], {"foo": True}]
"value",
[
1,
True,
"str",
3.1415,
["foo", True, 1.2],
{"foo": True},
User(age=1, name="foo"),
{"user": User(age=1, name="foo")},
],
)
def test_assign_wrong_type_to_dict(self, class_type: str, value: Any) -> None:
module: Any = import_module(class_type)
Expand Down Expand Up @@ -886,24 +896,18 @@ def test_str2str_with_field(self, class_type: str) -> None:
with pytest.raises(KeyValidationError):
cfg[Color.RED] = "fail"

def test_str2int_with_field_of_different_type(self, class_type: str) -> None:
module: Any = import_module(class_type)
cfg = OmegaConf.structured(module.DictSubclass.Str2IntWithStrField())
assert cfg.foo == "bar"

cfg.one = 1
assert cfg.one == 1

with pytest.raises(ValidationError):
# bad
cfg.hello = "world"

class TestErrors:
def test_usr2str(self, class_type: str) -> None:
module: Any = import_module(class_type)
with pytest.raises(KeyValidationError):
OmegaConf.structured(module.DictSubclass.Error.User2Str())

def test_str2int_with_field_of_different_type(self, class_type: str) -> None:
module: Any = import_module(class_type)
cfg = OmegaConf.structured(module.DictSubclass.Str2IntWithStrField())
with pytest.raises(ValidationError):
cfg.foo = "str"

def test_construct_from_another_retain_node_types(self, class_type: str) -> None:
module: Any = import_module(class_type)
cfg1 = OmegaConf.create(module.User(name="James Bond", age=7))
Expand Down

0 comments on commit ea210fb

Please sign in to comment.