From ff6eb43f5dcd8d7458dba90df3848963a85ceb41 Mon Sep 17 00:00:00 2001 From: Omry Yadan Date: Tue, 19 May 2020 14:04:07 -0700 Subject: [PATCH] CLI add and delete support for defaults and config --- examples/notebook/%run_test.ipynb | 4 +- .../notebook/hydra_notebook_example.ipynb | 4 +- hydra/_internal/config_loader_impl.py | 233 ++++++++++--- hydra/core/utils.py | 26 -- .../configs/completion_test/config.yaml | 4 +- news/598.api_change | 1 + news/598.feature | 1 + .../hydra_ax_sweeper/tests/config/config.yaml | 2 + tests/test_compose.py | 12 +- tests/test_config_loader.py | 321 +++++++++++++++--- .../test_structured_configs_tutorial.py | 7 +- tests/test_examples/test_tutorials_basic.py | 2 +- tests/test_hydra.py | 10 +- 13 files changed, 489 insertions(+), 138 deletions(-) create mode 100644 news/598.api_change create mode 100644 news/598.feature diff --git a/examples/notebook/%run_test.ipynb b/examples/notebook/%run_test.ipynb index 77536883ee5..d94be156b37 100644 --- a/examples/notebook/%run_test.ipynb +++ b/examples/notebook/%run_test.ipynb @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -44,7 +44,7 @@ } ], "source": [ - "%run ../tutorials/basic/5_composition/my_app.py" + "%run ../tutorials/basic/your_first_hydra_app/6_putting_it_all_together/my_app.py" ] } ], diff --git a/examples/notebook/hydra_notebook_example.ipynb b/examples/notebook/hydra_notebook_example.ipynb index 670c431ce78..15fe3541c8f 100644 --- a/examples/notebook/hydra_notebook_example.ipynb +++ b/examples/notebook/hydra_notebook_example.ipynb @@ -91,7 +91,7 @@ "source": [ "# Compose a config from scratch\n", "# Missing actual user and password as those are environment specific\n", - "cfg=hydra.experimental.compose(overrides=[\"db=sqlite\"])\n", + "cfg=hydra.experimental.compose(overrides=[\"+db=sqlite\"])\n", "print(cfg.pretty())" ] }, @@ -117,7 +117,7 @@ ], "source": [ "# compose application specific config (in this case the applicaiton is \"donkey\")\n", - "cfg=hydra.experimental.compose(overrides= [\"db=mysql\", \"environment=production\", \"application=donkey\"])\n", + "cfg=hydra.experimental.compose(overrides= [\"+db=mysql\", \"+environment=production\", \"+application=donkey\"])\n", "print(cfg.pretty())" ] }, diff --git a/hydra/_internal/config_loader_impl.py b/hydra/_internal/config_loader_impl.py index f234482cf16..5379e8622f4 100644 --- a/hydra/_internal/config_loader_impl.py +++ b/hydra/_internal/config_loader_impl.py @@ -9,18 +9,32 @@ from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple -from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict +import yaml +from omegaconf import DictConfig, ListConfig, OmegaConf, _utils, open_dict from omegaconf.errors import OmegaConfBaseException from hydra._internal.config_repository import ConfigRepository from hydra.core.config_loader import ConfigLoader, LoadTrace from hydra.core.config_search_path import ConfigSearchPath from hydra.core.object_type import ObjectType -from hydra.core.utils import JobRuntime, get_overrides_dirname +from hydra.core.utils import JobRuntime from hydra.errors import HydraException, MissingConfigException from hydra.plugins.config_source import ConfigLoadError, ConfigSource +@dataclass +class ParsedConfigOverride: + prefix: Optional[str] + key: str + value: str + + def is_delete(self) -> bool: + return self.prefix == "~" + + def is_add(self) -> bool: + return self.prefix == "+" + + @dataclass class ParsedOverride: prefix: Optional[str] @@ -29,14 +43,27 @@ class ParsedOverride: pkg2: Optional[str] value: Optional[str] + def get_source_package(self) -> Optional[str]: + return self.pkg1 + def get_subject_package(self) -> Optional[str]: return self.pkg1 if self.pkg2 is None else self.pkg2 + def get_source_item(self) -> str: + pkg = self.get_source_package() + if pkg is None: + return self.key + else: + return f"{self.key}@{pkg}" + def is_package_rename(self) -> bool: return self.pkg1 is not None and self.pkg2 is not None def is_delete(self) -> bool: - return self.value == "null" + return self.prefix == "~" or self.value == "null" + + def is_add(self) -> bool: + return self.prefix == "+" @dataclass @@ -92,14 +119,14 @@ def __init__( def split_overrides( self, pairs: List[ParsedOverrideWithLine], - ) -> Tuple[List[ParsedOverride], List[ParsedOverrideWithLine]]: + ) -> Tuple[List[ParsedOverrideWithLine], List[ParsedOverrideWithLine]]: config_group_overrides = [] config_overrides = [] for pwd in pairs: if not self.repository.exists(pwd.override.key): config_overrides.append(pwd) else: - config_group_overrides.append(pwd.override) + config_group_overrides.append(pwd) return config_group_overrides, config_overrides def load_configuration( @@ -165,29 +192,26 @@ def load_configuration( OmegaConf.set_struct(cfg, strict) # Merge all command line overrides after enabling strict flag - try: - lst = [x.input_line for x in config_overrides] - merged = OmegaConf.merge(cfg, OmegaConf.from_dotlist(lst)) - assert isinstance(merged, DictConfig) - cfg = merged - except OmegaConfBaseException as ex: - raise HydraException("Error merging overrides") from ex + ConfigLoaderImpl._apply_overrides_to_config( + [x.input_line for x in config_overrides], cfg + ) + app_overrides = [] for pwl in parsed_overrides: override = pwl.override - assert override.value is not None assert override.key is not None key = override.key if key.startswith("hydra.") or key.startswith("hydra/"): cfg.hydra.overrides.hydra.append(pwl.input_line) else: cfg.hydra.overrides.task.append(pwl.input_line) + app_overrides.append(pwl) with open_dict(cfg.hydra.job): if "name" not in cfg.hydra.job: cfg.hydra.job.name = JobRuntime().get("name") cfg.hydra.job.override_dirname = get_overrides_dirname( - input_list=cfg.hydra.overrides.task, + input_list=app_overrides, kv_sep=cfg.hydra.job.config.override_dirname.kv_sep, item_sep=cfg.hydra.job.config.override_dirname.item_sep, exclude_keys=cfg.hydra.job.config.override_dirname.exclude_keys, @@ -255,7 +279,7 @@ def find_matches( @staticmethod def _apply_overrides_to_defaults( - overrides: List[ParsedOverride], defaults: List[DefaultElement] + overrides: List[ParsedOverrideWithLine], defaults: List[DefaultElement], ) -> None: key_to_defaults: Dict[str, List[IndexedDefaultElement]] = defaultdict(list) @@ -265,35 +289,70 @@ def _apply_overrides_to_defaults( key_to_defaults[default.config_group].append( IndexedDefaultElement(idx=idx, default=default) ) - for override in overrides: + for owl in overrides: + override = owl.override + if override.is_add() and override.is_package_rename(): + raise HydraException( + "Add syntax does not support package rename, remove + prefix" + ) + if override.value is not None and "," in override.value: # If this is a multirun config (comma separated list), flag the default to prevent it from being # loaded until we are constructing the config for individual jobs. override.value = "_SKIP_" - if override.value == "null": - matches = ConfigLoaderImpl.find_matches(key_to_defaults, override) + matches = ConfigLoaderImpl.find_matches(key_to_defaults, override) + + if override.is_delete(): + src = override.get_source_item() + if len(matches) == 0: + raise HydraException( + f"Could not delete. No match for '{src}' in the defaults list." + ) for pair in matches: + if ( + override.value is not None + and override.value != defaults[pair.idx].config_name + ): + raise HydraException( + f"Could not delete. No match for '{src}={override.value}' in the defaults list." + ) + del defaults[pair.idx] + elif override.is_add(): + if len(matches) > 0: + src = override.get_source_item() + raise HydraException( + f"Could not add. An item matching '{src}' is already in the defaults list." + ) + assert override.value is not None + defaults.append( + DefaultElement( + config_group=override.key, + config_name=override.value, + package=override.get_subject_package(), + ) + ) else: - matches = ConfigLoaderImpl.find_matches(key_to_defaults, override) - + # override for match in matches: default = match.default if override.value is not None: default.config_name = override.value if override.pkg1 is not None: default.package = override.get_subject_package() - if len(matches) == 0 and not ( - override.is_package_rename() or override.is_delete() - ): - defaults.append( - DefaultElement( - config_group=override.key, - config_name=override.value, - package=override.get_subject_package(), + + if len(matches) == 0: + src = override.get_source_item() + if override.is_package_rename(): + msg = f"Could not rename package. No match for '{src}' in the defaults list." + else: + msg = ( + f"Could not override. No match for '{src}' in the defaults list." + f"\nTo append to your default list, prefix the override with plus. e.g +{owl.input_line}" ) - ) + + raise HydraException(msg) @staticmethod def _split_group(group_with_package: str) -> Tuple[str, Optional[str]]: @@ -309,36 +368,116 @@ def _split_group(group_with_package: str) -> Tuple[str, Optional[str]]: return group, package + @staticmethod + def _apply_overrides_to_config(overrides: List[str], cfg: DictConfig) -> None: + loader = _utils.get_yaml_loader() + + def get_value(val: str) -> Any: + return yaml.load(val, Loader=loader) if override.value is not None else None + + for line in overrides: + override = ConfigLoaderImpl._parse_config_override(line) + try: + value = get_value(override.value) + if override.is_delete(): + val = OmegaConf.select(cfg, override.key, throw_on_missing=False) + if val is None: + raise HydraException( + f"Could not delete from config. '{override.key}' does not exist." + ) + elif value is not None and value != val: + raise HydraException( + f"Could not delete from config." + f" The value of '{override.key}' is {val} and not {override.value}." + ) + + key = override.key + last_dot = key.rfind(".") + if last_dot == -1: + del cfg[key] + else: + node = OmegaConf.select(cfg, key[0:last_dot]) + del node[key[last_dot + 1 :]] + + elif override.is_add(): + if ( + OmegaConf.select(cfg, override.key, throw_on_missing=False) + is None + ): + with open_dict(cfg): + OmegaConf.update(cfg, override.key, value) + else: + raise HydraException( + f"Could not append to config. An item is already at '{override.key}'." + ) + else: + OmegaConf.update(cfg, override.key, value) + except OmegaConfBaseException as ex: + raise HydraException(f"Error merging override {line}") from ex + @staticmethod def _parse_override(override: str) -> ParsedOverrideWithLine: # forms: # key=value # key@pkg=value # key@src_pkg:dst_pkg=value - # regex code and tests: https://regex101.com/r/LiV6Rf/13 + # regex code and tests: https://regex101.com/r/LiV6Rf/14 regex = ( - r"^(?P[+-])?(?P[A-Za-z0-9_.-/]+)(?:@(?P[A-Za-z0-9_\.-]*)" + r"^(?P[+~])?(?P[A-Za-z0-9_.-/]+)(?:@(?P[A-Za-z0-9_\.-]*)" r"(?::(?P[A-Za-z0-9_\.-]*)?)?)?(?:=(?P.*))?$" ) matches = re.search(regex, override) - + msg = ( + f"Error parsing command line override : '{override}'" + f"\nAccepted forms:" + f"\n" + f"\n\tOverride:\tkey=value, key@package=value, key@src_pkg:dest_pkg=value, key@src_pkg:dest_pkg" + f"\n\tAppend:\t+key=value, +key@package=value" + f"\n\tDelete:\t~key@pkg, key@pkg=null" + ) if matches: prefix = matches.group("prefix") key = matches.group("key") pkg1 = matches.group("pkg1") pkg2 = matches.group("pkg2") - value = matches.group("value") + value: Optional[str] = matches.group("value") + if value == "null": + if prefix not in (None, "~"): + raise HydraException(msg) + prefix = "~" + value = None ret = ParsedOverride(prefix, key, pkg1, pkg2, value) return ParsedOverrideWithLine(override=ret, input_line=override) else: - raise HydraException( - f"Error parsing command line override : '{override}'\n" - f"Accepted forms:\n" - f"\tkey=value\n" - f"\tkey@dest_pkg=value\n" - f"\tkey@src_pkg:dest_pkg=value" + raise HydraException(msg) + + @staticmethod + def _parse_config_override(override: str) -> ParsedConfigOverride: + # forms: + # update: key=value + # append: +key=value + # delete: ~key=value | ~key + # regex code and tests: https://regex101.com/r/JAPVdx/4 + + regex = r"^(?P[+~])?(?P.*?)(=(?P.*))?$" + matches = re.search(regex, override) + if matches: + return ParsedConfigOverride( + prefix=matches.group("prefix"), + key=matches.group("key"), + value=matches.group("value"), + ) + else: + msg = ( + f"Error parsing config override : '{override}'" + f"\nAccepted forms:" + f"\n" + f"\n\tOverride:\tkey=value" + f"\n\tAppend:\t+key=value" + f"\n\tDelete:\t~key@pkg" ) + raise HydraException(msg) @staticmethod def _combine_default_lists( @@ -645,3 +784,19 @@ def _parse_defaults(cfg: DictConfig) -> List[DefaultElement]: def get_sources(self) -> List[ConfigSource]: return self.repository.get_sources() + + +def get_overrides_dirname( + input_list: List[ParsedOverrideWithLine], + exclude_keys: List[str] = [], + item_sep: str = ",", + kv_sep: str = "=", +) -> str: + lines = [] + for x in input_list: + if x.override.key not in exclude_keys: + lines.append(x.input_line) + + lines.sort() + ret = re.sub(pattern="[=]", repl=kv_sep, string=item_sep.join(lines)) + return ret diff --git a/hydra/core/utils.py b/hydra/core/utils.py index 5cc8892481d..2488ed052af 100644 --- a/hydra/core/utils.py +++ b/hydra/core/utils.py @@ -60,23 +60,6 @@ def _save_config(cfg: DictConfig, filename: str, output_dir: Path) -> None: file.write(cfg.pretty()) -def get_overrides_dirname( - input_list: Sequence[str], - exclude_keys: Sequence[str] = [], - item_sep: str = ",", - kv_sep: str = "=", -) -> str: - lst = [] - for x in input_list: - key, _val = split_key_val(x) - if key not in exclude_keys: - lst.append(x) - - lst.sort() - ret = re.sub(pattern="[=]", repl=kv_sep, string=item_sep.join(lst)) - return ret - - def filter_overrides(overrides: Sequence[str]) -> Sequence[str]: """ :param overrides: overrides list @@ -85,15 +68,6 @@ def filter_overrides(overrides: Sequence[str]) -> Sequence[str]: return [x for x in overrides if not x.startswith("hydra.")] -def split_key_val(s: str) -> Tuple[str, str]: - if "=" not in s: - raise ValueError(f"'{s}' not a valid override, expecting key=value format") - - idx = s.find("=") - assert idx != -1 - return s[0:idx], s[idx + 1 :] - - def run_job( config: DictConfig, task_function: TaskFunction, diff --git a/hydra/test_utils/configs/completion_test/config.yaml b/hydra/test_utils/configs/completion_test/config.yaml index 8f3f87fa2ea..18625995cbb 100644 --- a/hydra/test_utils/configs/completion_test/config.yaml +++ b/hydra/test_utils/configs/completion_test/config.yaml @@ -1,4 +1,6 @@ -# @package _global_ +defaults: + - group: null + # a mapping item dict: key1: val1 diff --git a/news/598.api_change b/news/598.api_change new file mode 100644 index 00000000000..c772b596c7c --- /dev/null +++ b/news/598.api_change @@ -0,0 +1 @@ +Appending config groups to the defaults list via the command line now requires a + prefix \ No newline at end of file diff --git a/news/598.feature b/news/598.feature new file mode 100644 index 00000000000..a694550ed58 --- /dev/null +++ b/news/598.feature @@ -0,0 +1 @@ +Changes command line processing (requiring + and ~ prefixes for appending and removing items) \ No newline at end of file diff --git a/plugins/hydra_ax_sweeper/tests/config/config.yaml b/plugins/hydra_ax_sweeper/tests/config/config.yaml index 65b28e470c6..ce10f8cac41 100644 --- a/plugins/hydra_ax_sweeper/tests/config/config.yaml +++ b/plugins/hydra_ax_sweeper/tests/config/config.yaml @@ -1,6 +1,8 @@ # @package _group_ defaults: - hydra/sweeper: ax + - quadratic: null + - params: null quadratic: # To be minimized diff --git a/tests/test_compose.py b/tests/test_compose.py index a687675eb7f..c8ecaca8fc0 100644 --- a/tests/test_compose.py +++ b/tests/test_compose.py @@ -137,7 +137,7 @@ def test_strict_failure_disabled_on_call( (None, [], {}), ( None, - ["db=sqlite"], + ["+db=sqlite"], { "db": { "driver": "sqlite", @@ -149,12 +149,12 @@ def test_strict_failure_disabled_on_call( ), ( None, - ["db=mysql", "environment=production"], + ["+db=mysql", "+environment=production"], {"db": {"driver": "mysql", "user": "mysql", "pass": "r4Zn*jQ9JB1Rz2kfz"}}, ), ( None, - ["db=mysql", "environment=production", "application=donkey"], + ["+db=mysql", "+environment=production", "+application=donkey"], { "db": {"driver": "mysql", "user": "mysql", "pass": "r4Zn*jQ9JB1Rz2kfz"}, "donkey": {"name": "kong", "rank": "king"}, @@ -163,9 +163,9 @@ def test_strict_failure_disabled_on_call( ( None, [ - "db=mysql", - "environment=production", - "application=donkey", + "+db=mysql", + "+environment=production", + "+application=donkey", "donkey.name=Dapple", "donkey.rank=squire_donkey", ], diff --git a/tests/test_config_loader.py b/tests/test_config_loader.py index a440cd8c4c5..799f83d8192 100644 --- a/tests/test_config_loader.py +++ b/tests/test_config_loader.py @@ -1,13 +1,14 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -# TODO: print error if source package is not found: python two_packages.py db@MISSING:source1=mysql -# TODO: Implement and test: https://docs.google.com/document/d/1I--p8JpIWQujVZuyaM2J910ew9wJ01S0E3ye6uJnTmY/edit# # TODO: bad error for: # python examples/tutorials/basic/your_first_hydra_app/5_selecting_defaults_for_config_groups/my_app.py db= -# TODO: final verdict about + and - prefixes for config groups. -# TODO: decide on header comments before the package. -# TODO fix error in: python two_packages.py db@destination:backup - +# TODO : If not config file is specified, do not require + prefix to add items to defaults or config. +# TODO: Document command line: +# +/~, pacakges, defaults manipulation, the works. +# completion +# TODO: Add tests for completion with +prefix (should complete and suggest config groups that are not listed) +# TODO : Test completion when defaults has a missing mandatory item +import re from dataclasses import dataclass from typing import Any, List @@ -18,6 +19,7 @@ from hydra._internal.config_loader_impl import ( ConfigLoaderImpl, DefaultElement, + ParsedConfigOverride, ParsedOverride, ) from hydra._internal.utils import create_config_search_path @@ -162,7 +164,7 @@ def test_load_changing_group_and_package_in_default( id="baseline", ), pytest.param( - ["group1@pkg3=option1"], + ["+group1@pkg3=option1"], { "pkg1": {"group1_option1": True}, "pkg2": {"group1_option1": True}, @@ -175,6 +177,11 @@ def test_load_changing_group_and_package_in_default( {"pkg2": {"group1_option1": True}}, id="delete_package", ), + pytest.param( + ["~group1@pkg1"], + {"pkg2": {"group1_option1": True}}, + id="delete_package", + ), pytest.param( ["group1@pkg1:new_pkg=option1"], {"new_pkg": {"group1_option1": True}, "pkg2": {"group1_option1": True}}, @@ -206,7 +213,7 @@ def test_load_adding_group_not_in_default(self, path: str) -> None: ) cfg = config_loader.load_configuration( config_name="optional-default.yaml", - overrides=["group2=file1"], + overrides=["+group2=file1"], strict=False, ) with open_dict(cfg): @@ -335,7 +342,7 @@ def test_load_config_with_schema(self, restore_singletons: Any, path: str) -> No ) cfg = config_loader.load_configuration( - config_name="config", overrides=["db=mysql"] + config_name="config", overrides=["+db=mysql"] ) with open_dict(cfg): del cfg["hydra"] @@ -377,7 +384,7 @@ def test_load_config_file_with_schema_validation( config_search_path=create_config_search_path(path) ) cfg = config_loader.load_configuration( - config_name="config", overrides=["db=mysql"], strict=False + config_name="config", overrides=["+db=mysql"], strict=False ) with open_dict(cfg): @@ -692,7 +699,7 @@ def test_overlapping_schemas(restore_singletons: Any) -> None: assert OmegaConf.get_type(cfg.plugin) == Plugin cfg = config_loader.load_configuration( - config_name="config", overrides=["plugin=concrete"] + config_name="config", overrides=["+plugin=concrete"] ) with open_dict(cfg): del cfg["hydra"] @@ -760,7 +767,7 @@ def test_complex_defaults(overrides: Any, expected: Any) -> None: @pytest.mark.parametrize( # type: ignore "override, expected", [ - # changing items + # changing item pytest.param( "db=postgresql", ParsedOverride(None, "db", None, None, "postgresql"), @@ -786,7 +793,7 @@ def test_complex_defaults(overrides: Any, expected: Any) -> None: ParsedOverride(None, "db", "src", "dest", None), id="change_package", ), - # adding items + # adding item pytest.param( "+model=resnet", ParsedOverride("+", "model", None, None, "resnet"), @@ -797,32 +804,21 @@ def test_complex_defaults(overrides: Any, expected: Any) -> None: ParsedOverride("+", "db", "offsite_backup", None, "mysql"), id="add_item", ), - # deleting items + # deleting item pytest.param( - "-db", ParsedOverride("-", "db", None, None, None), id="delete_item", + "~db", ParsedOverride("~", "db", None, None, None), id="delete_item", ), pytest.param( - "-db@src", ParsedOverride("-", "db", "src", None, None), id="delete_item", + "~db@src", ParsedOverride("~", "db", "src", None, None), id="delete_item", ), pytest.param( - "db=null", ParsedOverride(None, "db", None, None, "null"), id="delete_item", + "db=null", ParsedOverride("~", "db", None, None, None), id="delete_item", ), pytest.param( "db@src=null", - ParsedOverride(None, "db", "src", None, "null"), + ParsedOverride("~", "db", "src", None, None), id="delete_item", ), - # old - ("key@pkg=value", ParsedOverride(None, "key", "pkg", None, "value")), - ("key@pkg1:pkg2=value", ParsedOverride(None, "key", "pkg1", "pkg2", "value")), - ( - "key@a.b.c:x.y.z=value", - ParsedOverride(None, "key", "a.b.c", "x.y.z", "value"), - ), - ("key@:pkg2=value", ParsedOverride(None, "key", "", "pkg2", "value")), - ("key@pkg1:=value", ParsedOverride(None, "key", "pkg1", "", "value")), - ("key=null", ParsedOverride(None, "key", None, None, "null")), - ("foo/bar=zoo", ParsedOverride(None, "foo/bar", None, None, "zoo")), ], ) def test_parse_override(override: str, expected: ParsedOverride) -> None: @@ -830,13 +826,32 @@ def test_parse_override(override: str, expected: ParsedOverride) -> None: assert ret.override == expected +@pytest.mark.parametrize( # type: ignore + "override, expected", + [ + pytest.param( + "x.y.z=abc", ParsedConfigOverride(None, "x.y.z", "abc"), id="change_option", + ), + pytest.param( + "+x.y.z=abc", ParsedConfigOverride("+", "x.y.z", "abc"), id="adding", + ), + pytest.param( + "~x.y.z=abc", ParsedConfigOverride("~", "x.y.z", "abc"), id="adding", + ), + ], +) +def test_parse_config_override(override: str, expected: ParsedConfigOverride) -> None: + ret = ConfigLoaderImpl._parse_config_override(override) + assert ret == expected + + defaults_list = [{"db": "mysql"}, {"db@src": "mysql"}, {"hydra/launcher": "basic"}] -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "input_defaults,overrides,expected", [ - # change + # change item pytest.param( defaults_list, ["db=postgresql"], @@ -865,6 +880,17 @@ def test_parse_override(override: str, expected: ParsedOverride) -> None: [{"db": "mysql"}, {"db@dest": "postgresql"}, {"hydra/launcher": "basic"}], id="change_both", ), + pytest.param( + defaults_list, + ["db@XXX:dest=postgresql"], + pytest.raises( + HydraException, + match=re.escape( + "Could not rename package. No match for 'db@XXX' in the defaults list." + ), + ), + id="change_both_invalid_package", + ), pytest.param( defaults_list, ["db@:dest"], @@ -877,39 +903,226 @@ def test_parse_override(override: str, expected: ParsedOverride) -> None: [{"db": "mysql"}, {"db@dest": "mysql"}, {"hydra/launcher": "basic"}], id="change_package", ), + pytest.param( + defaults_list, + ["db@XXX:dest"], + pytest.raises( + HydraException, + match=re.escape( + "Could not rename package. No match for 'db@XXX' in the defaults list." + ), + ), + id="change_package_from_invalid", + ), + # adding item + pytest.param([], ["+db=mysql"], [{"db": "mysql"}], id="adding_item"), + pytest.param( + defaults_list, + ["+db@backup=mysql"], + [ + {"db": "mysql"}, + {"db@src": "mysql"}, + {"hydra/launcher": "basic"}, + {"db@backup": "mysql"}, + ], + id="adding_item_at_package", + ), + pytest.param( + defaults_list, + ["+db=mysql"], + pytest.raises( + HydraException, + match=re.escape( + "Could not add. An item matching 'db' is already in the defaults list" + ), + ), + id="adding_duplicate_item", + ), + pytest.param( + defaults_list, + ["+db@src:foo=mysql"], + pytest.raises( + HydraException, + match=re.escape( + "Add syntax does not support package rename, remove + prefix" + ), + ), + id="add_rename_error", + ), + pytest.param( + defaults_list, + ["+db@src=mysql"], + pytest.raises( + HydraException, + match=re.escape( + "Could not add. An item matching 'db@src' is already in the defaults list" + ), + ), + id="adding_duplicate_item", + ), + pytest.param( + [], + ["db=mysql"], + pytest.raises( + HydraException, + match=re.escape( + "Could not override. No match for 'db' in the defaults list.\n" + "To append to your default list, prefix the override with plus. e.g +db=mysql" + ), + ), + id="adding_without_plus", + ), + # deleting item + pytest.param( + [], + ["~db=mysql"], + pytest.raises( + HydraException, + match=re.escape( + "Could not delete. No match for 'db' in the defaults list." + ), + ), + id="delete_no_match", + ), + pytest.param( + defaults_list, + ["~db"], + [{"db@src": "mysql"}, {"hydra/launcher": "basic"}], + id="delete", + ), + pytest.param( + defaults_list, + ["~db=mysql"], + [{"db@src": "mysql"}, {"hydra/launcher": "basic"}], + id="delete", + ), + pytest.param( + defaults_list, + ["~db=postgresql"], + pytest.raises( + HydraException, + match=re.escape( + "Could not delete. No match for 'db=postgresql' in the defaults list." + ), + ), + id="delete_mismatch_value", + ), + pytest.param( + defaults_list, + ["~db@src"], + [{"db": "mysql"}, {"hydra/launcher": "basic"}], + id="delete", + ), ], ) def test_apply_overrides_to_defaults( - input_defaults: List[str], overrides: List[str], expected: List[Any], + input_defaults: List[str], overrides: List[str], expected: Any ) -> None: - parsed_overrides = [ - ConfigLoaderImpl._parse_override(override).override for override in overrides - ] - input_defaults = ConfigLoaderImpl._parse_defaults( + defaults = ConfigLoaderImpl._parse_defaults( OmegaConf.create({"defaults": input_defaults}) ) - expected_defaults = ConfigLoaderImpl._parse_defaults( - OmegaConf.create({"defaults": expected}) - ) - ConfigLoaderImpl._apply_overrides_to_defaults( - overrides=parsed_overrides, defaults=input_defaults - ) - assert input_defaults == expected_defaults + parsed_overrides = [ + ConfigLoaderImpl._parse_override(override) for override in overrides + ] + + if isinstance(expected, list): + expected_defaults = ConfigLoaderImpl._parse_defaults( + OmegaConf.create({"defaults": expected}) + ) + ConfigLoaderImpl._apply_overrides_to_defaults( + overrides=parsed_overrides, defaults=defaults + ) + assert defaults == expected_defaults + else: + with expected: + ConfigLoaderImpl._apply_overrides_to_defaults( + overrides=parsed_overrides, defaults=defaults + ) @pytest.mark.parametrize( # type: ignore - "overrides, expected", + "input_cfg,strict,overrides,expected", [ - # interpreted as a dotlist - (["user@hostname=active"], {"user@hostname": "active"}), - (["user@hostname.com=active"], {"user@hostname": {"com": "active"}}), + # append + pytest.param({}, False, ["x=10"], {"x": 10}, id="append"), + pytest.param( + {}, + True, + ["x=10"], + pytest.raises( + HydraException, match=re.escape("Error merging override x=10") + ), + id="append", + ), + pytest.param({}, True, ["+x=10"], {"x": 10}, id="append"), + # append item with @ + pytest.param( + {}, + False, + ["user@hostname=active"], + {"user@hostname": "active"}, + id="append@", + ), + pytest.param( + {}, + True, + ["+user@hostname=active"], + {"user@hostname": "active"}, + id="append@", + ), + # override + pytest.param({"x": 20}, False, ["x=10"], {"x": 10}, id="override"), + pytest.param({"x": 20}, True, ["x=10"], {"x": 10}, id="override"), + pytest.param( + {"x": 20}, + True, + ["+x=10"], + pytest.raises( + HydraException, + match=re.escape( + "Could not append to config. An item is already at 'x'" + ), + ), + id="override", + ), + # delete + pytest.param({"x": 20}, False, ["~x"], {}, id="delete"), + pytest.param({"x": 20}, False, ["~x=20"], {}, id="delete"), + pytest.param({"x": {"y": 10}}, False, ["~x"], {}, id="delete"), + pytest.param({"x": {"y": 10}}, False, ["~x.y"], {"x": {}}, id="delete"), + pytest.param({"x": {"y": 10}}, False, ["~x.y=10"], {"x": {}}, id="delete"), + pytest.param( + {"x": 20}, + False, + ["~z"], + pytest.raises( + HydraException, + match=re.escape("Could not delete from config. 'z' does not exist."), + ), + id="delete_error_key", + ), + pytest.param( + {"x": 20}, + False, + ["~x=10"], + pytest.raises( + HydraException, + match=re.escape( + "Could not delete from config. The value of 'x' is 20 and not 10." + ), + ), + id="delete_error_value", + ), ], ) -def test_override_config_key_with_at_symbol( - overrides: List[str], expected: Any +def test_apply_overrides_to_config( + input_cfg: Any, strict: bool, overrides: List[str], expected: Any ) -> None: - config_loader = ConfigLoaderImpl(config_search_path=create_config_search_path(None)) - cfg = config_loader.load_configuration(config_name=None, overrides=overrides) - with open_dict(cfg): - del cfg["hydra"] - assert cfg == expected + cfg = OmegaConf.create(input_cfg) + OmegaConf.set_struct(cfg, strict) + if isinstance(expected, dict): + ConfigLoaderImpl._apply_overrides_to_config(overrides=overrides, cfg=cfg) + assert cfg == expected + else: + with expected: + ConfigLoaderImpl._apply_overrides_to_config(overrides=overrides, cfg=cfg) diff --git a/tests/test_examples/test_structured_configs_tutorial.py b/tests/test_examples/test_structured_configs_tutorial.py index 69e6820557f..c2e843a5aa9 100644 --- a/tests/test_examples/test_structured_configs_tutorial.py +++ b/tests/test_examples/test_structured_configs_tutorial.py @@ -92,7 +92,10 @@ def test_structured_configs_2_nesting_configs__with_ad_hoc_node(tmpdir: Path) -> "overrides,expected", [ ([], {"db": "???"}), - (["db=mysql"], {"db": {"driver": "mysql", "host": "localhost", "port": 3306}},), + ( + ["+db=mysql"], + {"db": {"driver": "mysql", "host": "localhost", "port": 3306}}, + ), ], ) def test_structured_configs_3_config_groups( @@ -114,7 +117,7 @@ def test_structured_configs_3_config_groups_with_inheritance(tmpdir: Path) -> No sys.executable, "examples/tutorials/structured_configs/3_config_groups/my_app_with_inheritance.py", "hydra.run.dir=" + str(tmpdir), - "db=mysql", + "+db=mysql", ] result = check_output(cmd) assert result.decode().rstrip() == "Connecting to MySQL: localhost:3306" diff --git a/tests/test_examples/test_tutorials_basic.py b/tests/test_examples/test_tutorials_basic.py index e9a1adafe5c..1bed0728690 100644 --- a/tests/test_examples/test_tutorials_basic.py +++ b/tests/test_examples/test_tutorials_basic.py @@ -129,7 +129,7 @@ def test_tutorial_config_file_bad_key( [ ([], OmegaConf.create()), ( - ["db=postgresql"], + ["+db=postgresql"], OmegaConf.create( { "db": { diff --git a/tests/test_hydra.py b/tests/test_hydra.py index 66c4cf1750f..8a5bd286892 100644 --- a/tests/test_hydra.py +++ b/tests/test_hydra.py @@ -293,7 +293,7 @@ def test_app_with_config_groups__override_dataset__wrong( calling_module=calling_module, config_path="conf", config_name=None, - overrides=["optimizer=wrong_name"], + overrides=["+optimizer=wrong_name"], ): pass assert sorted(ex.value.options) == sorted(["adam", "nesterov"]) @@ -317,7 +317,7 @@ def test_app_with_config_groups__override_all_configs( calling_module=calling_module, config_path="conf", config_name=None, - overrides=["optimizer=adam", "optimizer.lr=10"], + overrides=["+optimizer=adam", "optimizer.lr=10"], ) as task: assert task.job_ret is not None and task.job_ret.cfg == dict( optimizer=dict(type="adam", lr=10, beta=0.01) @@ -440,7 +440,7 @@ def test_cfg(tmpdir: Path, flag: str, expected_keys: List[str]) -> None: (None, "tests.test_apps.app_with_config_with_free_group.my_app"), ], ) -@pytest.mark.parametrize("overrides", [["free_group=opt1,opt2"]]) # type: ignore +@pytest.mark.parametrize("overrides", [["+free_group=opt1,opt2"]]) # type: ignore def test_multirun_with_free_override( restore_singletons: Any, sweep_runner: TSweepRunner, @@ -459,9 +459,9 @@ def test_multirun_with_free_override( ) with sweep: assert sweep.returns is not None and len(sweep.returns[0]) == 2 - assert sweep.returns[0][0].overrides == ["free_group=opt1"] + assert sweep.returns[0][0].overrides == ["+free_group=opt1"] assert sweep.returns[0][0].cfg == {"group_opt1": True, "free_group_opt1": True} - assert sweep.returns[0][1].overrides == ["free_group=opt2"] + assert sweep.returns[0][1].overrides == ["+free_group=opt2"] assert sweep.returns[0][1].cfg == {"group_opt1": True, "free_group_opt2": True}