diff --git a/news/400.feature b/news/400.feature new file mode 100644 index 000000000..95e5c300b --- /dev/null +++ b/news/400.feature @@ -0,0 +1 @@ +flag_override can now override multiple flags at the same time diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index df681b1ab..b063a2ddb 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -708,14 +708,21 @@ def to_yaml(cfg: Any, *, resolve: bool = False, sort_keys: bool = False) -> str: @contextmanager def flag_override( - config: Node, name: str, value: Optional[bool] + config: Node, names: Union[List[str], str], value: Optional[bool] ) -> Generator[Node, None, None]: - prev_state = config._get_flag(name) + + if isinstance(names, str): + names = [names] + + prev_states = [config._get_flag(name) for name in names] + try: - config._set_flag(name, value) + for idx, name in enumerate(names): + config._set_flag(name, value) yield config finally: - config._set_flag(name, prev_state) + for idx, name in enumerate(names): + config._set_flag(name, prev_states[idx]) @contextmanager diff --git a/tests/test_base_config.py b/tests/test_base_config.py index cd21df7cb..fc343d26f 100644 --- a/tests/test_base_config.py +++ b/tests/test_base_config.py @@ -18,7 +18,7 @@ open_dict, read_write, ) -from omegaconf.errors import ConfigKeyError +from omegaconf.errors import ConfigAttributeError, ConfigKeyError from . import Color, StructuredWithMissing, User, does_not_raise @@ -440,6 +440,24 @@ def test_flag_override( func(c) +def test_multiple_flags_override() -> None: + c = OmegaConf.create({"foo": "bar"}) + with flag_override(c, ["readonly"], True): + with pytest.raises(ReadonlyConfigError): + c.foo = 10 + + with flag_override(c, ["struct"], True): + with pytest.raises(ConfigAttributeError): + c.x = 10 + + with flag_override(c, ["struct", "readonly"], True): + with pytest.raises(ConfigAttributeError): + c.x = 10 + + with pytest.raises(ReadonlyConfigError): + c.foo = 20 + + @pytest.mark.parametrize( # type: ignore "src, func, expectation", [