Add force_add to OmegaConf.update(), effectively using open_dict for all nodes along the path #665

Merged
merged 5 commits into from
Apr 7, 2021
Changes from 2 commits
17 changes: 17 additions & 0 deletions benchmark/benchmark.py
@@ -149,3 +149,20 @@ def test_get_value_kind(

 def test_is_missing_literal(benchmark: Any) -> None:
     assert benchmark(_is_missing_literal, "???")
+
+
+@mark.parametrize("force_add", [False, True])
+def test_update_force_add(
+    large_dict_config: Any, force_add: bool, benchmark: Any
+) -> None:
Comment on lines +155 to +157
Owner Author

Added a benchmark to verify that the extra flag overrides do not cause a performance regression.
The concern is that the recursive flags cache invalidation triggered when a flag is set will slow things down.
It does not seem to cause a significant slowdown for a large dict_config here though.

Collaborator

I don't think there's too much to worry about here because in the worst case it would only invalidate the whole config cache once: after it's cleared, subsequent cache invalidation wouldn't trickle down anymore.
That being said, I'm not sure whether large_dict_config actually has the flags cache set on all config nodes, which would be needed to evaluate this worst-case scenario. If that's not the case, then the benchmark isn't very relevant (at least to evaluate the impact of cache invalidation on performance).

I also suggested an optimization that should reduce the amount of cache invalidation.
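
The "invalidate once, then stop trickling down" argument above can be illustrated with a toy model. This is only a sketch: ToyNode, populate, invalidate and total_visits are hypothetical names invented for illustration, and the code does not reproduce OmegaConf's real _invalidate_flags_cache. It assumes a node whose cache is already empty does not descend into its children.

```python
from typing import List, Optional


class ToyNode:
    """Toy stand-in for a config node; not OmegaConf's actual classes."""

    def __init__(self, children: Optional[List["ToyNode"]] = None) -> None:
        self.children = children or []
        self.flags_cache: Optional[dict] = None  # None means "not populated"
        self.visits = 0  # how many times invalidate() reached this node

    def populate(self) -> None:
        # Simulate every node having resolved (and cached) its flags.
        self.flags_cache = {}
        for child in self.children:
            child.populate()

    def invalidate(self) -> None:
        self.visits += 1
        if self.flags_cache is None:
            return  # gate: nothing cached here, so do not descend
        self.flags_cache = None
        for child in self.children:
            child.invalidate()


def total_visits(node: ToyNode) -> int:
    return node.visits + sum(total_visits(c) for c in node.children)


# A root with 3 children, each with 3 leaves: 13 nodes in total.
root = ToyNode([ToyNode([ToyNode() for _ in range(3)]) for _ in range(3)])
root.populate()

root.invalidate()
first = total_visits(root)  # the first invalidation walks the whole tree

root.invalidate()
second = total_visits(root) - first  # the second one stops at the root
```

In this model the full-tree walk only happens when the caches are populated, which is why a benchmark that never populates them would not exercise the worst case.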

Owner Author
@omry omry Apr 7, 2021

Yeah, I didn't notice the gating in _invalidate_flags_cache.
Good point about the benchmark not doing much.

I fixed the benchmark to populate the flags cache on all nodes. We don't have a baseline though.
Overall I am not too concerned about relative performance here: at 400us per call on the big config, this is plenty fast.

+    if force_add:
+        OmegaConf.set_struct(large_dict_config, True)
+
+    benchmark(
+        OmegaConf.update,
+        large_dict_config,
+        "a.a.a.a.a.a.a.a.a.a.a",
+        10,
+        merge=True,
+        force_add=force_add,
+    )
5 changes: 5 additions & 0 deletions docs/source/usage.rst
@@ -863,6 +863,7 @@ OmegaConf.update() allows you to update values in your config using either a dot

 The merge flag controls the behavior if the input is a dict or a list. If it's true, those are merged instead of
 being assigned.
+The force_add flag will add the key even if the config or a node along the way is a Structured Config or is in struct mode.

.. doctest::

@@ -876,6 +877,10 @@ being assigned.
     >>> # Merge dictionary value (using bracket notation)
     >>> OmegaConf.update(cfg, "foo[bar]", {"oompa" : 40}, merge=True)
     >>> assert cfg.foo.bar == {"zonk" : 30, "oompa" : 40}
+    >>> # force_add ignores nodes in struct mode and updates anyway.
+    >>> OmegaConf.set_struct(cfg, True)
+    >>> OmegaConf.update(cfg, "a.b.c.d", 10, merge=True, force_add=True)
+    >>> assert cfg.a.b.c.d == 10



1 change: 1 addition & 0 deletions news/664.feature
@@ -0,0 +1 @@
+OmegaConf.update() accepts a new flag force_add (default False), which ensures the config path will be created even if a node along the way is in struct mode.
58 changes: 34 additions & 24 deletions omegaconf/omegaconf.py
@@ -727,7 +727,12 @@ def select(

     @staticmethod
     def update(
-        cfg: Container, key: str, value: Any = None, merge: Optional[bool] = None
+        cfg: Container,
+        key: str,
+        value: Any = None,
+        *,
+        merge: Optional[bool] = None,
+        force_add: bool = False,
     ) -> None:
         """
         Updates a dot separated key sequence to a value
@@ -739,7 +744,8 @@ def update(
         :param merge: If value is a dict or a list, True for merge, False for set.
                       True to merge
                       False to set
-                      None (default) : deprecation warning and default to False
+                      None (default): deprecation warning and default to False
+        :param force_add: insert the entire path regardless of Struct flag or Structured Config nodes.
Collaborator

It may be surprising for force_add not to allow adding to readonly configs, is that intended?

(if it's changed, it may be easier to rename it to force so that it can force both adding new values and overriding existing ones, as otherwise you'd need to distinguish between the two situations for readonly configs, which seems tricky)

Owner Author
@omry omry Apr 7, 2021

The reported problem is specifically about adding a field to a Structured Config.
The original implementation in Hydra used open_dict on the root of the config, which was not sufficient because Structured Configs do not inherit the struct flag from their parent.

Read-only is used here as a mechanism to protect nodes against changes.
In contrast to the reported problem, using with read_write(cfg): will work fine because its behavior is "normal"; in fact, frozen support for dataclasses is implemented by setting the recursive read-only flag on the DictConfig node.

Collaborator

I see. I feel like it may be possible to have a version of open_dict() (or, more generally, flag_override()) that behaves as if all children had their flag set to the desired value (which would avoid the issue of intermediate nodes causing trouble because of their local flags). It'd be a bit tricky to get it to work efficiently though (you don't want to actually propagate the desired flag to all children -- only those whose flags are being accessed within the context)...

         """

         if merge is None:
@@ -757,12 +763,14 @@ def update(
         split = split_key(key)
         root = cfg
         for i in range(len(split) - 1):
-            k = split[i]
-            # if next_root is a primitive (string, int etc) replace it with an empty map
-            next_root, key_ = _select_one(root, k, throw_on_missing=False)
-            if not isinstance(next_root, Container):
-                root[key_] = {}
-            root = root[key_]
+            struct_override = False if force_add else root._get_node_flag("struct")
+            with flag_override(root, "struct", struct_override):
+                k = split[i]
+                # if next_root is a primitive (string, int etc) replace it with an empty map
+                next_root, key_ = _select_one(root, k, throw_on_missing=False)
+                if not isinstance(next_root, Container):
+                    root[key_] = {}
+                root = root[key_]
Comment on lines +766 to +773
Collaborator
@odelalleau odelalleau Apr 7, 2021

Optimization by only calling flag_override when needed:

Suggested change
-            struct_override = False if force_add else root._get_node_flag("struct")
-            with flag_override(root, "struct", struct_override):
-                k = split[i]
-                # if next_root is a primitive (string, int etc) replace it with an empty map
-                next_root, key_ = _select_one(root, k, throw_on_missing=False)
-                if not isinstance(next_root, Container):
-                    root[key_] = {}
-                root = root[key_]
+            k = split[i]
+            # if next_root is a primitive (string, int etc) replace it with an empty map
+            next_root, key_ = _select_one(root, k, throw_on_missing=False)
+            if not isinstance(next_root, Container):
+                if force_add:
+                    with flag_override(root, "struct", False):
+                        root[key_] = {}
+                else:
+                    root[key_] = {}
+            root = root[key_]

Owner Author

The code here will not apply cleanly (see the syntax error on line 775).
In fact, we do have a does_not_raise context manager, similar to nullcontext, in the tests. We can promote it to utils.
Feel free to follow up with a PR optimizing both places.

Collaborator

Ah, oops, looks like my mouse select skills need improving. I just added the missing bracket.


         last = split[-1]

@@ -774,22 +782,24 @@ def update(
         if isinstance(root, ListConfig):
             last_key = int(last)

-        if merge and (OmegaConf.is_config(value) or is_primitive_container(value)):
-            assert isinstance(root, BaseContainer)
-            node = root._get_node(last_key)
-            if OmegaConf.is_config(node):
-                assert isinstance(node, BaseContainer)
-                node.merge_with(value)
-                return
-
-        if OmegaConf.is_dict(root):
-            assert isinstance(last_key, str)
-            root.__setattr__(last_key, value)
-        elif OmegaConf.is_list(root):
-            assert isinstance(last_key, int)
-            root.__setitem__(last_key, value)
-        else:
-            assert False
+        struct_override = False if force_add else root._get_node_flag("struct")
+        with flag_override(root, "struct", struct_override):
Comment on lines +785 to +786
Collaborator

A similar optimization could be done to avoid overriding for no reason. Something like:

Suggested change
-        struct_override = False if force_add else root._get_node_flag("struct")
-        with flag_override(root, "struct", struct_override):
+        ctx = flag_override(root, "struct", False) if force_add else nullcontext()
+        with ctx:

However, nullcontext was only added in Python 3.7, so it would require writing our own implementation for 3.6 in _utils.py.
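
One possible shape for such a fallback, as a sketch only: it uses the standard library where available and defines a small replacement otherwise. The toy_override and flag_seen_inside helpers are hypothetical stand-ins for flag_override and for the call site in update(); they are not part of OmegaConf.

```python
import contextlib
from typing import Any, Iterator

try:
    from contextlib import nullcontext  # available from Python 3.7
except ImportError:  # Python 3.6 fallback

    @contextlib.contextmanager
    def nullcontext(enter_result: Any = None) -> Iterator[Any]:
        # Do nothing on enter or exit; just hand back enter_result.
        yield enter_result


@contextlib.contextmanager
def toy_override(state: dict, flag: str, value: Any) -> Iterator[None]:
    # Hypothetical stand-in for flag_override: set a flag, restore it on exit.
    missing = object()
    previous = state.get(flag, missing)
    state[flag] = value
    try:
        yield
    finally:
        if previous is missing:
            del state[flag]
        else:
            state[flag] = previous


def flag_seen_inside(state: dict, force_add: bool) -> Any:
    # Only pay for the override when force_add is requested.
    ctx = toy_override(state, "struct", False) if force_add else nullcontext()
    with ctx:
        return state["struct"]
```

With this pattern the body of the with block is written once, and the no-op context keeps the non-forced path free of any flag changes.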

+            if merge and (OmegaConf.is_config(value) or is_primitive_container(value)):
+                assert isinstance(root, BaseContainer)
+                node = root._get_node(last_key)
+                if OmegaConf.is_config(node):
+                    assert isinstance(node, BaseContainer)
+                    node.merge_with(value)
+                    return
+
+            if OmegaConf.is_dict(root):
+                assert isinstance(last_key, str)
+                root.__setattr__(last_key, value)
+            elif OmegaConf.is_list(root):
+                assert isinstance(last_key, int)
+                root.__setitem__(last_key, value)
+            else:
+                assert False

     @staticmethod
     def to_yaml(cfg: Any, *, resolve: bool = False, sort_keys: bool = False) -> str:
23 changes: 23 additions & 0 deletions tests/test_update.py
@@ -6,6 +6,7 @@

 from omegaconf import ListConfig, OmegaConf, ValidationError
 from omegaconf._utils import _ensure_container, is_primitive_container
+from omegaconf.errors import ConfigAttributeError, ConfigKeyError
 from tests import Package


@@ -195,3 +196,25 @@ def test_merge_deprecation() -> None:
     with warns(UserWarning, match=re.escape(msg)):
         OmegaConf.update(cfg, "a", {"c": 20})  # default to set, and issue a warning.
     assert cfg == {"a": {"c": 20}}
+
+
+@mark.parametrize(
+    "cfg,key,value,expected",
+    [
+        param({}, "a", 10, {"a": 10}, id="add_value"),
+        param({}, "a.b", 10, {"a": {"b": 10}}, id="add_value"),
+        param({}, "a", {"b": 10}, {"a": {"b": 10}}, id="add_dict"),
+        param({}, "a.b", {"c": 10}, {"a": {"b": {"c": 10}}}, id="add_dict"),
+        param({}, "a", [1, 2], {"a": [1, 2]}, id="add_list"),
+        param({}, "a.b", [1, 2], {"a": {"b": [1, 2]}}, id="add_list"),
+    ],
+)
+def test_update_force_add(cfg: Any, key: str, value: Any, expected: Any) -> None:
+    cfg = _ensure_container(cfg)
+    OmegaConf.set_struct(cfg, True)
+
+    with raises((ConfigAttributeError, ConfigKeyError)):  # type: ignore
+        OmegaConf.update(cfg, key, value, merge=True, force_add=False)
+
+    OmegaConf.update(cfg, key, value, merge=True, force_add=True)
+    assert cfg == expected