From 656706817a4f3de1900fe0fc43cb7aed7a74da36 Mon Sep 17 00:00:00 2001 From: Omry Yadan Date: Tue, 6 Apr 2021 17:21:56 -0700 Subject: [PATCH 1/5] Add force_add to OmegaConf.update(), effectively using open_dict for all nodes along the path --- news/664.feature | 1 + omegaconf/omegaconf.py | 58 +++++++++++++++++++++++++----------------- tests/test_update.py | 23 +++++++++++++++++ 3 files changed, 58 insertions(+), 24 deletions(-) create mode 100644 news/664.feature diff --git a/news/664.feature b/news/664.feature new file mode 100644 index 000000000..a4895ccdf --- /dev/null +++ b/news/664.feature @@ -0,0 +1 @@ +OmegaConf.update() accept a new flag force_add (default False), which ensures the config path will be created even if a node along the way is in strut mode. diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index cc7714f3e..358c21311 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -727,7 +727,12 @@ def select( @staticmethod def update( - cfg: Container, key: str, value: Any = None, merge: Optional[bool] = None + cfg: Container, + key: str, + value: Any = None, + *, + merge: Optional[bool] = None, + force_add: bool = False, ) -> None: """ Updates a dot separated key sequence to a value @@ -739,7 +744,8 @@ def update( :param merge: If value is a dict or a list, True for merge, False for set. True to merge False to set - None (default) : deprecation warning and default to False + None (default): deprecation warning and default to False + :param force_add: insert the entire path regardless of Struct flag or Structured Config nodes. """ if merge is None: @@ -757,12 +763,14 @@ def update( split = split_key(key) root = cfg for i in range(len(split) - 1): - k = split[i] - # if next_root is a primitive (string, int etc) replace it with an empty map - next_root, key_ = _select_one(root, k, throw_on_missing=False) - if not isinstance(next_root, Container): - root[key_] = {} - root = root[key_] + struct_override = False if force_add else root._get_node_flag("struct") + with flag_override(root, "struct", struct_override): + k = split[i] + # if next_root is a primitive (string, int etc) replace it with an empty map + next_root, key_ = _select_one(root, k, throw_on_missing=False) + if not isinstance(next_root, Container): + root[key_] = {} + root = root[key_] last = split[-1] @@ -774,22 +782,24 @@ def update( if isinstance(root, ListConfig): last_key = int(last) - if merge and (OmegaConf.is_config(value) or is_primitive_container(value)): - assert isinstance(root, BaseContainer) - node = root._get_node(last_key) - if OmegaConf.is_config(node): - assert isinstance(node, BaseContainer) - node.merge_with(value) - return - - if OmegaConf.is_dict(root): - assert isinstance(last_key, str) - root.__setattr__(last_key, value) - elif OmegaConf.is_list(root): - assert isinstance(last_key, int) - root.__setitem__(last_key, value) - else: - assert False + struct_override = False if force_add else root._get_node_flag("struct") + with flag_override(root, "struct", struct_override): + if merge and (OmegaConf.is_config(value) or is_primitive_container(value)): + assert isinstance(root, BaseContainer) + node = root._get_node(last_key) + if OmegaConf.is_config(node): + assert isinstance(node, BaseContainer) + node.merge_with(value) + return + + if OmegaConf.is_dict(root): + assert isinstance(last_key, str) + root.__setattr__(last_key, value) + elif OmegaConf.is_list(root): + assert isinstance(last_key, int) + root.__setitem__(last_key, value) + else: + assert False @staticmethod def to_yaml(cfg: Any, *, resolve: bool = False, sort_keys: bool = False) -> str: diff --git a/tests/test_update.py b/tests/test_update.py index b34c876c8..dd6278362 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -6,6 +6,7 @@ from omegaconf import ListConfig, OmegaConf, ValidationError from omegaconf._utils import _ensure_container, is_primitive_container +from omegaconf.errors import ConfigAttributeError, ConfigKeyError from tests import Package @@ -195,3 +196,25 @@ def test_merge_deprecation() -> None: with warns(UserWarning, match=re.escape(msg)): OmegaConf.update(cfg, "a", {"c": 20}) # default to set, and issue a warning. assert cfg == {"a": {"c": 20}} + + +@mark.parametrize( + "cfg,key,value,expected", + [ + param({}, "a", 10, {"a": 10}, id="add_value"), + param({}, "a.b", 10, {"a": {"b": 10}}, id="add_value"), + param({}, "a", {"b": 10}, {"a": {"b": 10}}, id="add_dict"), + param({}, "a.b", {"c": 10}, {"a": {"b": {"c": 10}}}, id="add_dict"), + param({}, "a", [1, 2], {"a": [1, 2]}, id="add_list"), + param({}, "a.b", [1, 2], {"a": {"b": [1, 2]}}, id="add_list"), + ], +) +def test_update_force_add(cfg: Any, key: str, value: Any, expected: Any) -> None: + cfg = _ensure_container(cfg) + OmegaConf.set_struct(cfg, True) + + with raises((ConfigAttributeError, ConfigKeyError)): # type: ignore + OmegaConf.update(cfg, key, value, merge=True, force_add=False) + + OmegaConf.update(cfg, key, value, merge=True, force_add=True) + assert cfg == expected From 3d26236d87ad6aa4c58ca9a4e737720c0785a957 Mon Sep 17 00:00:00 2001 From: Omry Yadan Date: Tue, 6 Apr 2021 18:06:02 -0700 Subject: [PATCH 2/5] added a benchmark and updated docs --- benchmark/benchmark.py | 17 +++++++++++++++++ docs/source/usage.rst | 5 +++++ 2 files changed, 22 insertions(+) diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index 7a1f6d223..b7914482f 100644 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -149,3 +149,20 @@ def test_get_value_kind( def test_is_missing_literal(benchmark: Any) -> None: assert benchmark(_is_missing_literal, "???") + + +@mark.parametrize("force_add", [False, True]) +def test_update_force_add( + large_dict_config: Any, force_add: bool, benchmark: Any +) -> None: + if force_add: + OmegaConf.set_struct(large_dict_config, True) + + benchmark( + OmegaConf.update, + large_dict_config, + "a.a.a.a.a.a.a.a.a.a.a", + 10, + merge=True, + force_add=force_add, + ) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 16974f917..a5e2fa34b 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -863,6 +863,7 @@ OmegaConf.update() allows you to update values in your config using either a dot The merge flag controls the behavior if the input is a dict or a list. If it's true, those are merged instead of being assigned. +The force_add flag will add the key even if the config or a node along the way a Structured Config or is in struct mode. .. doctest:: @@ -876,6 +877,10 @@ being assigned. >>> # Merge dictionary value (using bracket notation) >>> OmegaConf.update(cfg, "foo[bar]", {"oompa" : 40}, merge=True) >>> assert cfg.foo.bar == {"zonk" : 30, "oompa" : 40} + >>> # force_add ignores nodes in struct mode and updates anyway. + >>> OmegaConf.set_struct(cfg, True) + >>> OmegaConf.update(cfg, "a.b.c.d", 10, merge=True, force_add=True) + >>> assert cfg.a.b.c.d == 10 From 431c095c32709f04ac3647ba12bc95e448ee3876 Mon Sep 17 00:00:00 2001 From: Omry Yadan Date: Tue, 6 Apr 2021 18:25:23 -0700 Subject: [PATCH 3/5] updated news and usage --- docs/source/usage.rst | 2 +- news/664.feature | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index a5e2fa34b..ccf25d605 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -863,7 +863,7 @@ OmegaConf.update() allows you to update values in your config using either a dot The merge flag controls the behavior if the input is a dict or a list. If it's true, those are merged instead of being assigned. -The force_add flag will add the key even if the config or a node along the way a Structured Config or is in struct mode. +The force_add flag ensures a the path is created even if it will result in insertion of new values into struct nodes. .. doctest:: diff --git a/news/664.feature b/news/664.feature index a4895ccdf..79b50b67b 100644 --- a/news/664.feature +++ b/news/664.feature @@ -1 +1 @@ -OmegaConf.update() accept a new flag force_add (default False), which ensures the config path will be created even if a node along the way is in strut mode. +force_add flag added to OmegaConf.update(), ensuring a the path is created even if it will result in insertion of new values into struct nodes. From d7cf7fbe97498dece1e21ec0837068ef8dda9c1b Mon Sep 17 00:00:00 2001 From: Omry Yadan Date: Tue, 6 Apr 2021 19:29:27 -0700 Subject: [PATCH 4/5] feedback --- docs/source/usage.rst | 2 +- news/664.feature | 2 +- tests/test_update.py | 9 ++++++++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index ccf25d605..2b616c119 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -863,7 +863,7 @@ OmegaConf.update() allows you to update values in your config using either a dot The merge flag controls the behavior if the input is a dict or a list. If it's true, those are merged instead of being assigned. -The force_add flag ensures a the path is created even if it will result in insertion of new values into struct nodes. +The force_add flag ensures that the path is created even if it will result in insertion of new values into struct nodes. .. doctest:: diff --git a/news/664.feature b/news/664.feature index 79b50b67b..6709f418b 100644 --- a/news/664.feature +++ b/news/664.feature @@ -1 +1 @@ -force_add flag added to OmegaConf.update(), ensuring a the path is created even if it will result in insertion of new values into struct nodes. +force_add flag added to OmegaConf.update(), ensuring that the path is created even if it will result in insertion of new values into struct nodes. \ No newline at end of file diff --git a/tests/test_update.py b/tests/test_update.py index dd6278362..bf44782fb 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -7,7 +7,7 @@ from omegaconf import ListConfig, OmegaConf, ValidationError from omegaconf._utils import _ensure_container, is_primitive_container from omegaconf.errors import ConfigAttributeError, ConfigKeyError -from tests import Package +from tests import Package, User @mark.parametrize( @@ -207,6 +207,13 @@ def test_merge_deprecation() -> None: param({}, "a.b", {"c": 10}, {"a": {"b": {"c": 10}}}, id="add_dict"), param({}, "a", [1, 2], {"a": [1, 2]}, id="add_list"), param({}, "a.b", [1, 2], {"a": {"b": [1, 2]}}, id="add_list"), + param( + {"user": User(name="Bond", age=7)}, + "user.location", + "London", + {"user": {"name": "Bond", "age": 7, "location": "London"}}, + id="inserting_into_nested_structured_config", + ), ], ) def test_update_force_add(cfg: Any, key: str, value: Any, expected: Any) -> None: From 02a0c6b8cf74da8b8e93f26ec0bae93fc557b099 Mon Sep 17 00:00:00 2001 From: Omry Yadan Date: Wed, 7 Apr 2021 16:01:01 -0700 Subject: [PATCH 5/5] updated benchmark to populate flags cache --- benchmark/benchmark.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index b7914482f..3d9a2c091 100644 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -158,6 +158,14 @@ def test_update_force_add( if force_add: OmegaConf.set_struct(large_dict_config, True) + def recursive_is_struct(node: Any) -> None: + if OmegaConf.is_config(node): + OmegaConf.is_struct(node) + for val in node.values(): + recursive_is_struct(val) + + recursive_is_struct(large_dict_config) + benchmark( OmegaConf.update, large_dict_config,