Skip to content

Commit

Permalink
flag_override can now override multiple flags at the same time
Browse files Browse the repository at this point in the history
  • Loading branch information
omry committed Sep 30, 2020
1 parent c1b7298 commit 5e9de21
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 5 deletions.
1 change: 1 addition & 0 deletions news/400.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
flag_override can now override multiple flags at the same time
15 changes: 11 additions & 4 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,14 +708,21 @@ def to_yaml(cfg: Any, *, resolve: bool = False, sort_keys: bool = False) -> str:

@contextmanager
def flag_override(
config: Node, name: str, value: Optional[bool]
config: Node, names: Union[List[str], str], value: Optional[bool]
) -> Generator[Node, None, None]:
prev_state = config._get_flag(name)

if isinstance(names, str):
names = [names]

prev_states = [config._get_flag(name) for name in names]

try:
config._set_flag(name, value)
for idx, name in enumerate(names):
config._set_flag(name, value)
yield config
finally:
config._set_flag(name, prev_state)
for idx, name in enumerate(names):
config._set_flag(name, prev_states[idx])


@contextmanager
Expand Down
20 changes: 19 additions & 1 deletion tests/test_base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
open_dict,
read_write,
)
from omegaconf.errors import ConfigKeyError
from omegaconf.errors import ConfigAttributeError, ConfigKeyError

from . import Color, StructuredWithMissing, User, does_not_raise

Expand Down Expand Up @@ -440,6 +440,24 @@ def test_flag_override(
func(c)


def test_multiple_flags_override() -> None:
c = OmegaConf.create({"foo": "bar"})
with flag_override(c, ["readonly"], True):
with pytest.raises(ReadonlyConfigError):
c.foo = 10

with flag_override(c, ["struct"], True):
with pytest.raises(ConfigAttributeError):
c.x = 10

with flag_override(c, ["struct", "readonly"], True):
with pytest.raises(ConfigAttributeError):
c.x = 10

with pytest.raises(ReadonlyConfigError):
c.foo = 20


@pytest.mark.parametrize( # type: ignore
"src, func, expectation",
[
Expand Down

0 comments on commit 5e9de21

Please sign in to comment.