diff --git a/news/586.bugfix b/news/586.bugfix new file mode 100644 index 000000000..ffdaf308b --- /dev/null +++ b/news/586.bugfix @@ -0,0 +1,2 @@ +Assignment of a dict to an existing node in a parent in struct mode no longer raises ValidationError + diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index 66d8b77ca..f6dc0aeb3 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -898,6 +898,10 @@ def _node_wrap( ref_type: Any = Any, ) -> Node: node: Node + allow_objects = parent is not None and parent._get_flag("allow_objects") is True + flags = {"allow_objects": allow_objects} if allow_objects is not None else {} + dummy = OmegaConf.create(flags=flags) + is_dict = is_primitive_dict(value) or is_dict_annotation(type_) is_list = ( type(value) in (list, tuple) @@ -909,7 +913,7 @@ def _node_wrap( node = DictConfig( content=value, key=key, - parent=parent, + parent=dummy, ref_type=type_, is_optional=is_optional, key_type=key_type, @@ -920,7 +924,7 @@ def _node_wrap( node = ListConfig( content=value, key=key, - parent=parent, + parent=dummy, is_optional=is_optional, element_type=element_type, ref_type=ref_type, @@ -932,33 +936,34 @@ def _node_wrap( is_optional=is_optional, content=value, key=key, - parent=parent, + parent=dummy, key_type=key_type, element_type=element_type, ) elif type_ == Any or type_ is None: - node = AnyNode(value=value, key=key, parent=parent, is_optional=is_optional) + node = AnyNode(value=value, key=key, parent=dummy, is_optional=is_optional) elif issubclass(type_, Enum): node = EnumNode( enum_type=type_, value=value, key=key, - parent=parent, + parent=dummy, is_optional=is_optional, ) elif type_ == int: - node = IntegerNode(value=value, key=key, parent=parent, is_optional=is_optional) + node = IntegerNode(value=value, key=key, parent=dummy, is_optional=is_optional) elif type_ == float: - node = FloatNode(value=value, key=key, parent=parent, is_optional=is_optional) + node = FloatNode(value=value, key=key, parent=dummy, is_optional=is_optional) elif type_ == bool: - node = BooleanNode(value=value, key=key, parent=parent, is_optional=is_optional) + node = BooleanNode(value=value, key=key, parent=dummy, is_optional=is_optional) elif type_ == str: - node = StringNode(value=value, key=key, parent=parent, is_optional=is_optional) + node = StringNode(value=value, key=key, parent=dummy, is_optional=is_optional) else: - if parent is not None and parent._get_flag("allow_objects") is True: - node = AnyNode(value=value, key=key, parent=parent, is_optional=is_optional) + if allow_objects: + node = AnyNode(value=value, key=key, parent=dummy, is_optional=is_optional) else: raise ValidationError(f"Unexpected object type : {type_str(type_)}") + node._set_parent(parent) return node diff --git a/tests/test_struct.py b/tests/test_struct.py index d15e23406..6bcc9f0e8 100644 --- a/tests/test_struct.py +++ b/tests/test_struct.py @@ -62,3 +62,10 @@ def test_struct_contain_missing() -> None: @mark.parametrize("cfg", [{}, OmegaConf.create({}, flags={"struct": True})]) def test_struct_dict_get(cfg: Any) -> None: assert cfg.get("z") is None + + +def test_struct_dict_assign() -> None: + cfg = OmegaConf.create({"a": {}}) + OmegaConf.set_struct(cfg, True) + cfg.a = {"b": 10} + assert cfg.a == {"b": 10}