Skip to content

Commit

Permalink
Struct assign (#587)
Browse files Browse the repository at this point in the history
  • Loading branch information
omry authored Mar 10, 2021
2 parents f0128ae + 1ef5521 commit 6661470
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 28 deletions.
1 change: 1 addition & 0 deletions news/586.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Assignment of a dict/list to an existing node in a parent in struct mode no longer raises ValidationError
27 changes: 13 additions & 14 deletions omegaconf/dictconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,26 +629,25 @@ def _set_value_impl(
self.__dict__["_content"] = {}
if is_structured_config(value):
self._metadata.object_type = None
data = get_structured_config_data(
value,
allow_objects=self._get_flag("allow_objects"),
)
for k, v in data.items():
self.__setitem__(k, v)
ao = self._get_flag("allow_objects")
data = get_structured_config_data(value, allow_objects=ao)
with flag_override(self, ["struct", "readonly"], False):
for k, v in data.items():
self.__setitem__(k, v)
self._metadata.object_type = get_type_of(value)

elif isinstance(value, DictConfig):
self.__dict__["_metadata"] = copy.deepcopy(value._metadata)
self._metadata.flags = copy.deepcopy(flags)
# disable struct and readonly for the construction phase
# retaining other flags like allow_objects. The real flags are restored at the end of this function
with flag_override(self, "struct", False):
with flag_override(self, "readonly", False):
for k, v in value.__dict__["_content"].items():
self.__setitem__(k, v)
with flag_override(self, ["struct", "readonly"], False):
for k, v in value.__dict__["_content"].items():
self.__setitem__(k, v)

elif isinstance(value, dict):
for k, v in value.items():
self.__setitem__(k, v)
with flag_override(self, ["struct", "readonly"], False):
for k, v in value.items():
self.__setitem__(k, v)

else: # pragma: no cover
msg = f"Unsupported value type : {value}"
raise ValidationError(msg)
Expand Down
12 changes: 6 additions & 6 deletions omegaconf/listconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,13 +593,13 @@ def _set_value_impl(
self._metadata.flags = copy.deepcopy(flags)
# disable struct and readonly for the construction phase
# retaining other flags like allow_objects. The real flags are restored at the end of this function
with flag_override(self, "struct", False):
with flag_override(self, "readonly", False):
for item in value._iter_ex(resolve=False):
self.append(item)
with flag_override(self, ["struct", "readonly"], False):
for item in value._iter_ex(resolve=False):
self.append(item)
elif is_primitive_list(value):
for item in value:
self.append(item)
with flag_override(self, ["struct", "readonly"], False):
for item in value:
self.append(item)

@staticmethod
def _list_eq(l1: Optional["ListConfig"], l2: Optional["ListConfig"]) -> bool:
Expand Down
9 changes: 9 additions & 0 deletions tests/test_basic_ops_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,3 +1023,12 @@ def test_dict_getitem_not_found() -> None:
def test_dict_getitem_none_output() -> None:
cfg = OmegaConf.create({"a": None})
assert cfg["a"] is None


@pytest.mark.parametrize("data", [{"b": 0}, User])
@pytest.mark.parametrize("flag", ["struct", "readonly"])
def test_dictconfig_creation_with_parent_flag(flag: str, data: Any) -> None:
parent = OmegaConf.create({"a": 10})
parent._set_flag(flag, True)
cfg = DictConfig(data, parent=parent)
assert cfg == data
9 changes: 9 additions & 0 deletions tests/test_basic_ops_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,3 +712,12 @@ def test_shallow_copy_none() -> None:
c._set_value([1])
assert c[0] == 1
assert cfg._is_none()


@pytest.mark.parametrize("flag", ["struct", "readonly"])
def test_listconfig_creation_with_parent_flag(flag: str) -> None:
parent = OmegaConf.create([])
parent._set_flag(flag, True)
d = [1, 2, 3]
cfg = ListConfig(d, parent=parent)
assert cfg == d
6 changes: 6 additions & 0 deletions tests/test_readonly.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
raises(ReadonlyConfigError, match="a"),
id="dict_setitem",
),
pytest.param(
{"a": None},
lambda c: c.__setitem__("a", {"b": 10}),
raises(ReadonlyConfigError, match="a"),
id="dict_setitem",
),
pytest.param(
{"a": {"b": {"c": 1}}},
lambda c: c.__getattr__("a").__getattr__("b").__setitem__("c", 1),
Expand Down
23 changes: 15 additions & 8 deletions tests/test_struct.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from typing import Any, Dict

import pytest
from pytest import mark, raises

from omegaconf import OmegaConf
from omegaconf.errors import ConfigKeyError
Expand All @@ -16,40 +16,40 @@ def test_struct_set_on_dict() -> None:
c = OmegaConf.create({"a": {}})
OmegaConf.set_struct(c, True)
# Throwing when it hits foo, so exception key is a.foo and not a.foo.bar
with pytest.raises(AttributeError, match=re.escape("a.foo")):
with raises(AttributeError, match=re.escape("a.foo")):
# noinspection PyStatementEffect
c.a.foo.bar


def test_struct_set_on_nested_dict() -> None:
c = OmegaConf.create({"a": {"b": 10}})
OmegaConf.set_struct(c, True)
with pytest.raises(AttributeError):
with raises(AttributeError):
# noinspection PyStatementEffect
c.foo

assert "a" in c
assert c.a.b == 10
with pytest.raises(AttributeError, match=re.escape("a.foo")):
with raises(AttributeError, match=re.escape("a.foo")):
# noinspection PyStatementEffect
c.a.foo


def test_merge_dotlist_into_struct() -> None:
c = OmegaConf.create({"a": {"b": 10}})
OmegaConf.set_struct(c, True)
with pytest.raises(AttributeError, match=re.escape("foo")):
with raises(AttributeError, match=re.escape("foo")):
c.merge_with_dotlist(["foo=1"])


@pytest.mark.parametrize("in_base, in_merged", [(dict(), dict(a=10))])
@mark.parametrize("in_base, in_merged", [({}, {"a": 10})])
def test_merge_config_with_struct(
in_base: Dict[str, Any], in_merged: Dict[str, Any]
) -> None:
base = OmegaConf.create(in_base)
merged = OmegaConf.create(in_merged)
OmegaConf.set_struct(base, True)
with pytest.raises(ConfigKeyError):
with raises(ConfigKeyError):
OmegaConf.merge(base, merged)


Expand All @@ -59,6 +59,13 @@ def test_struct_contain_missing() -> None:
assert "foo" not in c


@pytest.mark.parametrize("cfg", [{}, OmegaConf.create({}, flags={"struct": True})])
@mark.parametrize("cfg", [{}, OmegaConf.create({}, flags={"struct": True})])
def test_struct_dict_get(cfg: Any) -> None:
assert cfg.get("z") is None


def test_struct_dict_assign() -> None:
cfg = OmegaConf.create({"a": {}})
OmegaConf.set_struct(cfg, True)
cfg.a = {"b": 10}
assert cfg.a == {"b": 10}

0 comments on commit 6661470

Please sign in to comment.