diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index 7a1f6d223..3d9a2c091 100644 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -149,3 +149,28 @@ def test_get_value_kind( def test_is_missing_literal(benchmark: Any) -> None: assert benchmark(_is_missing_literal, "???") + + +@mark.parametrize("force_add", [False, True]) +def test_update_force_add( + large_dict_config: Any, force_add: bool, benchmark: Any +) -> None: + if force_add: + OmegaConf.set_struct(large_dict_config, True) + + def recursive_is_struct(node: Any) -> None: + if OmegaConf.is_config(node): + OmegaConf.is_struct(node) + for val in node.values(): + recursive_is_struct(val) + + recursive_is_struct(large_dict_config) + + benchmark( + OmegaConf.update, + large_dict_config, + "a.a.a.a.a.a.a.a.a.a.a", + 10, + merge=True, + force_add=force_add, + ) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 16974f917..2b616c119 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -863,6 +863,7 @@ OmegaConf.update() allows you to update values in your config using either a dot The merge flag controls the behavior if the input is a dict or a list. If it's true, those are merged instead of being assigned. +The force_add flag ensures that the path is created even if it will result in insertion of new values into struct nodes. .. doctest:: @@ -876,6 +877,10 @@ being assigned. >>> # Merge dictionary value (using bracket notation) >>> OmegaConf.update(cfg, "foo[bar]", {"oompa" : 40}, merge=True) >>> assert cfg.foo.bar == {"zonk" : 30, "oompa" : 40} + >>> # force_add ignores nodes in struct mode and updates anyway. + >>> OmegaConf.set_struct(cfg, True) + >>> OmegaConf.update(cfg, "a.b.c.d", 10, merge=True, force_add=True) + >>> assert cfg.a.b.c.d == 10 diff --git a/news/664.feature b/news/664.feature new file mode 100644 index 000000000..6709f418b --- /dev/null +++ b/news/664.feature @@ -0,0 +1 @@ +force_add flag added to OmegaConf.update(), ensuring that the path is created even if it will result in insertion of new values into struct nodes. \ No newline at end of file diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index cc7714f3e..358c21311 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -727,7 +727,12 @@ def select( @staticmethod def update( - cfg: Container, key: str, value: Any = None, merge: Optional[bool] = None + cfg: Container, + key: str, + value: Any = None, + *, + merge: Optional[bool] = None, + force_add: bool = False, ) -> None: """ Updates a dot separated key sequence to a value @@ -739,7 +744,8 @@ def update( :param merge: If value is a dict or a list, True for merge, False for set. True to merge False to set - None (default) : deprecation warning and default to False + None (default): deprecation warning and default to False + :param force_add: insert the entire path regardless of Struct flag or Structured Config nodes. """ if merge is None: @@ -757,12 +763,14 @@ def update( split = split_key(key) root = cfg for i in range(len(split) - 1): - k = split[i] - # if next_root is a primitive (string, int etc) replace it with an empty map - next_root, key_ = _select_one(root, k, throw_on_missing=False) - if not isinstance(next_root, Container): - root[key_] = {} - root = root[key_] + struct_override = False if force_add else root._get_node_flag("struct") + with flag_override(root, "struct", struct_override): + k = split[i] + # if next_root is a primitive (string, int etc) replace it with an empty map + next_root, key_ = _select_one(root, k, throw_on_missing=False) + if not isinstance(next_root, Container): + root[key_] = {} + root = root[key_] last = split[-1] @@ -774,22 +782,24 @@ def update( if isinstance(root, ListConfig): last_key = int(last) - if merge and (OmegaConf.is_config(value) or is_primitive_container(value)): - assert isinstance(root, BaseContainer) - node = root._get_node(last_key) - if OmegaConf.is_config(node): - assert isinstance(node, BaseContainer) - node.merge_with(value) - return - - if OmegaConf.is_dict(root): - assert isinstance(last_key, str) - root.__setattr__(last_key, value) - elif OmegaConf.is_list(root): - assert isinstance(last_key, int) - root.__setitem__(last_key, value) - else: - assert False + struct_override = False if force_add else root._get_node_flag("struct") + with flag_override(root, "struct", struct_override): + if merge and (OmegaConf.is_config(value) or is_primitive_container(value)): + assert isinstance(root, BaseContainer) + node = root._get_node(last_key) + if OmegaConf.is_config(node): + assert isinstance(node, BaseContainer) + node.merge_with(value) + return + + if OmegaConf.is_dict(root): + assert isinstance(last_key, str) + root.__setattr__(last_key, value) + elif OmegaConf.is_list(root): + assert isinstance(last_key, int) + root.__setitem__(last_key, value) + else: + assert False @staticmethod def to_yaml(cfg: Any, *, resolve: bool = False, sort_keys: bool = False) -> str: diff --git a/tests/test_update.py b/tests/test_update.py index b34c876c8..bf44782fb 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -6,7 +6,8 @@ from omegaconf import ListConfig, OmegaConf, ValidationError from omegaconf._utils import _ensure_container, is_primitive_container -from tests import Package +from omegaconf.errors import ConfigAttributeError, ConfigKeyError +from tests import Package, User @mark.parametrize( @@ -195,3 +196,32 @@ def test_merge_deprecation() -> None: with warns(UserWarning, match=re.escape(msg)): OmegaConf.update(cfg, "a", {"c": 20}) # default to set, and issue a warning. assert cfg == {"a": {"c": 20}} + + +@mark.parametrize( + "cfg,key,value,expected", + [ + param({}, "a", 10, {"a": 10}, id="add_value"), + param({}, "a.b", 10, {"a": {"b": 10}}, id="add_value"), + param({}, "a", {"b": 10}, {"a": {"b": 10}}, id="add_dict"), + param({}, "a.b", {"c": 10}, {"a": {"b": {"c": 10}}}, id="add_dict"), + param({}, "a", [1, 2], {"a": [1, 2]}, id="add_list"), + param({}, "a.b", [1, 2], {"a": {"b": [1, 2]}}, id="add_list"), + param( + {"user": User(name="Bond", age=7)}, + "user.location", + "London", + {"user": {"name": "Bond", "age": 7, "location": "London"}}, + id="inserting_into_nested_structured_config", + ), + ], +) +def test_update_force_add(cfg: Any, key: str, value: Any, expected: Any) -> None: + cfg = _ensure_container(cfg) + OmegaConf.set_struct(cfg, True) + + with raises((ConfigAttributeError, ConfigKeyError)): # type: ignore + OmegaConf.update(cfg, key, value, merge=True, force_add=False) + + OmegaConf.update(cfg, key, value, merge=True, force_add=True) + assert cfg == expected