Skip to content

Commit

Permalink
CLI add and delete support for defaults and config
Browse files Browse the repository at this point in the history
  • Loading branch information
omry committed May 21, 2020
1 parent 816dc0f commit ff6eb43
Show file tree
Hide file tree
Showing 13 changed files with 489 additions and 138 deletions.
4 changes: 2 additions & 2 deletions examples/notebook/%run_test.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 11,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -44,7 +44,7 @@
}
],
"source": [
"%run ../tutorials/basic/5_composition/my_app.py"
"%run ../tutorials/basic/your_first_hydra_app/6_putting_it_all_together/my_app.py"
]
}
],
Expand Down
4 changes: 2 additions & 2 deletions examples/notebook/hydra_notebook_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
"source": [
"# Compose a config from scratch\n",
"# Missing actual user and password as those are environment specific\n",
"cfg=hydra.experimental.compose(overrides=[\"db=sqlite\"])\n",
"cfg=hydra.experimental.compose(overrides=[\"+db=sqlite\"])\n",
"print(cfg.pretty())"
]
},
Expand All @@ -117,7 +117,7 @@
],
"source": [
"# compose application specific config (in this case the applicaiton is \"donkey\")\n",
"cfg=hydra.experimental.compose(overrides= [\"db=mysql\", \"environment=production\", \"application=donkey\"])\n",
"cfg=hydra.experimental.compose(overrides= [\"+db=mysql\", \"+environment=production\", \"+application=donkey\"])\n",
"print(cfg.pretty())"
]
},
Expand Down
233 changes: 194 additions & 39 deletions hydra/_internal/config_loader_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,32 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict
import yaml
from omegaconf import DictConfig, ListConfig, OmegaConf, _utils, open_dict
from omegaconf.errors import OmegaConfBaseException

from hydra._internal.config_repository import ConfigRepository
from hydra.core.config_loader import ConfigLoader, LoadTrace
from hydra.core.config_search_path import ConfigSearchPath
from hydra.core.object_type import ObjectType
from hydra.core.utils import JobRuntime, get_overrides_dirname
from hydra.core.utils import JobRuntime
from hydra.errors import HydraException, MissingConfigException
from hydra.plugins.config_source import ConfigLoadError, ConfigSource


@dataclass
class ParsedConfigOverride:
prefix: Optional[str]
key: str
value: str

def is_delete(self) -> bool:
return self.prefix == "~"

def is_add(self) -> bool:
return self.prefix == "+"


@dataclass
class ParsedOverride:
prefix: Optional[str]
Expand All @@ -29,14 +43,27 @@ class ParsedOverride:
pkg2: Optional[str]
value: Optional[str]

def get_source_package(self) -> Optional[str]:
return self.pkg1

def get_subject_package(self) -> Optional[str]:
return self.pkg1 if self.pkg2 is None else self.pkg2

def get_source_item(self) -> str:
pkg = self.get_source_package()
if pkg is None:
return self.key
else:
return f"{self.key}@{pkg}"

def is_package_rename(self) -> bool:
return self.pkg1 is not None and self.pkg2 is not None

def is_delete(self) -> bool:
return self.value == "null"
return self.prefix == "~" or self.value == "null"

def is_add(self) -> bool:
return self.prefix == "+"


@dataclass
Expand Down Expand Up @@ -92,14 +119,14 @@ def __init__(

def split_overrides(
self, pairs: List[ParsedOverrideWithLine],
) -> Tuple[List[ParsedOverride], List[ParsedOverrideWithLine]]:
) -> Tuple[List[ParsedOverrideWithLine], List[ParsedOverrideWithLine]]:
config_group_overrides = []
config_overrides = []
for pwd in pairs:
if not self.repository.exists(pwd.override.key):
config_overrides.append(pwd)
else:
config_group_overrides.append(pwd.override)
config_group_overrides.append(pwd)
return config_group_overrides, config_overrides

def load_configuration(
Expand Down Expand Up @@ -165,29 +192,26 @@ def load_configuration(
OmegaConf.set_struct(cfg, strict)

# Merge all command line overrides after enabling strict flag
try:
lst = [x.input_line for x in config_overrides]
merged = OmegaConf.merge(cfg, OmegaConf.from_dotlist(lst))
assert isinstance(merged, DictConfig)
cfg = merged
except OmegaConfBaseException as ex:
raise HydraException("Error merging overrides") from ex
ConfigLoaderImpl._apply_overrides_to_config(
[x.input_line for x in config_overrides], cfg
)

app_overrides = []
for pwl in parsed_overrides:
override = pwl.override
assert override.value is not None
assert override.key is not None
key = override.key
if key.startswith("hydra.") or key.startswith("hydra/"):
cfg.hydra.overrides.hydra.append(pwl.input_line)
else:
cfg.hydra.overrides.task.append(pwl.input_line)
app_overrides.append(pwl)

with open_dict(cfg.hydra.job):
if "name" not in cfg.hydra.job:
cfg.hydra.job.name = JobRuntime().get("name")
cfg.hydra.job.override_dirname = get_overrides_dirname(
input_list=cfg.hydra.overrides.task,
input_list=app_overrides,
kv_sep=cfg.hydra.job.config.override_dirname.kv_sep,
item_sep=cfg.hydra.job.config.override_dirname.item_sep,
exclude_keys=cfg.hydra.job.config.override_dirname.exclude_keys,
Expand Down Expand Up @@ -255,7 +279,7 @@ def find_matches(

@staticmethod
def _apply_overrides_to_defaults(
overrides: List[ParsedOverride], defaults: List[DefaultElement]
overrides: List[ParsedOverrideWithLine], defaults: List[DefaultElement],
) -> None:

key_to_defaults: Dict[str, List[IndexedDefaultElement]] = defaultdict(list)
Expand All @@ -265,35 +289,70 @@ def _apply_overrides_to_defaults(
key_to_defaults[default.config_group].append(
IndexedDefaultElement(idx=idx, default=default)
)
for override in overrides:
for owl in overrides:
override = owl.override
if override.is_add() and override.is_package_rename():
raise HydraException(
"Add syntax does not support package rename, remove + prefix"
)

if override.value is not None and "," in override.value:
# If this is a multirun config (comma separated list), flag the default to prevent it from being
# loaded until we are constructing the config for individual jobs.
override.value = "_SKIP_"

if override.value == "null":
matches = ConfigLoaderImpl.find_matches(key_to_defaults, override)
matches = ConfigLoaderImpl.find_matches(key_to_defaults, override)

if override.is_delete():
src = override.get_source_item()
if len(matches) == 0:
raise HydraException(
f"Could not delete. No match for '{src}' in the defaults list."
)
for pair in matches:
if (
override.value is not None
and override.value != defaults[pair.idx].config_name
):
raise HydraException(
f"Could not delete. No match for '{src}={override.value}' in the defaults list."
)

del defaults[pair.idx]
elif override.is_add():
if len(matches) > 0:
src = override.get_source_item()
raise HydraException(
f"Could not add. An item matching '{src}' is already in the defaults list."
)
assert override.value is not None
defaults.append(
DefaultElement(
config_group=override.key,
config_name=override.value,
package=override.get_subject_package(),
)
)
else:
matches = ConfigLoaderImpl.find_matches(key_to_defaults, override)

# override
for match in matches:
default = match.default
if override.value is not None:
default.config_name = override.value
if override.pkg1 is not None:
default.package = override.get_subject_package()
if len(matches) == 0 and not (
override.is_package_rename() or override.is_delete()
):
defaults.append(
DefaultElement(
config_group=override.key,
config_name=override.value,
package=override.get_subject_package(),

if len(matches) == 0:
src = override.get_source_item()
if override.is_package_rename():
msg = f"Could not rename package. No match for '{src}' in the defaults list."
else:
msg = (
f"Could not override. No match for '{src}' in the defaults list."
f"\nTo append to your default list, prefix the override with plus. e.g +{owl.input_line}"
)
)

raise HydraException(msg)

@staticmethod
def _split_group(group_with_package: str) -> Tuple[str, Optional[str]]:
Expand All @@ -309,36 +368,116 @@ def _split_group(group_with_package: str) -> Tuple[str, Optional[str]]:

return group, package

@staticmethod
def _apply_overrides_to_config(overrides: List[str], cfg: DictConfig) -> None:
loader = _utils.get_yaml_loader()

def get_value(val: str) -> Any:
return yaml.load(val, Loader=loader) if override.value is not None else None

for line in overrides:
override = ConfigLoaderImpl._parse_config_override(line)
try:
value = get_value(override.value)
if override.is_delete():
val = OmegaConf.select(cfg, override.key, throw_on_missing=False)
if val is None:
raise HydraException(
f"Could not delete from config. '{override.key}' does not exist."
)
elif value is not None and value != val:
raise HydraException(
f"Could not delete from config."
f" The value of '{override.key}' is {val} and not {override.value}."
)

key = override.key
last_dot = key.rfind(".")
if last_dot == -1:
del cfg[key]
else:
node = OmegaConf.select(cfg, key[0:last_dot])
del node[key[last_dot + 1 :]]

elif override.is_add():
if (
OmegaConf.select(cfg, override.key, throw_on_missing=False)
is None
):
with open_dict(cfg):
OmegaConf.update(cfg, override.key, value)
else:
raise HydraException(
f"Could not append to config. An item is already at '{override.key}'."
)
else:
OmegaConf.update(cfg, override.key, value)
except OmegaConfBaseException as ex:
raise HydraException(f"Error merging override {line}") from ex

@staticmethod
def _parse_override(override: str) -> ParsedOverrideWithLine:
# forms:
# key=value
# key@pkg=value
# key@src_pkg:dst_pkg=value
# regex code and tests: https://regex101.com/r/LiV6Rf/13
# regex code and tests: https://regex101.com/r/LiV6Rf/14

regex = (
r"^(?P<prefix>[+-])?(?P<key>[A-Za-z0-9_.-/]+)(?:@(?P<pkg1>[A-Za-z0-9_\.-]*)"
r"^(?P<prefix>[+~])?(?P<key>[A-Za-z0-9_.-/]+)(?:@(?P<pkg1>[A-Za-z0-9_\.-]*)"
r"(?::(?P<pkg2>[A-Za-z0-9_\.-]*)?)?)?(?:=(?P<value>.*))?$"
)
matches = re.search(regex, override)

msg = (
f"Error parsing command line override : '{override}'"
f"\nAccepted forms:"
f"\n"
f"\n\tOverride:\tkey=value, key@package=value, key@src_pkg:dest_pkg=value, key@src_pkg:dest_pkg"
f"\n\tAppend:\t+key=value, +key@package=value"
f"\n\tDelete:\t~key@pkg, key@pkg=null"
)
if matches:
prefix = matches.group("prefix")
key = matches.group("key")
pkg1 = matches.group("pkg1")
pkg2 = matches.group("pkg2")
value = matches.group("value")
value: Optional[str] = matches.group("value")
if value == "null":
if prefix not in (None, "~"):
raise HydraException(msg)
prefix = "~"
value = None
ret = ParsedOverride(prefix, key, pkg1, pkg2, value)
return ParsedOverrideWithLine(override=ret, input_line=override)
else:
raise HydraException(
f"Error parsing command line override : '{override}'\n"
f"Accepted forms:\n"
f"\tkey=value\n"
f"\tkey@dest_pkg=value\n"
f"\tkey@src_pkg:dest_pkg=value"
raise HydraException(msg)

@staticmethod
def _parse_config_override(override: str) -> ParsedConfigOverride:
# forms:
# update: key=value
# append: +key=value
# delete: ~key=value | ~key
# regex code and tests: https://regex101.com/r/JAPVdx/4

regex = r"^(?P<prefix>[+~])?(?P<key>.*?)(=(?P<value>.*))?$"
matches = re.search(regex, override)
if matches:
return ParsedConfigOverride(
prefix=matches.group("prefix"),
key=matches.group("key"),
value=matches.group("value"),
)
else:
msg = (
f"Error parsing config override : '{override}'"
f"\nAccepted forms:"
f"\n"
f"\n\tOverride:\tkey=value"
f"\n\tAppend:\t+key=value"
f"\n\tDelete:\t~key@pkg"
)
raise HydraException(msg)

@staticmethod
def _combine_default_lists(
Expand Down Expand Up @@ -645,3 +784,19 @@ def _parse_defaults(cfg: DictConfig) -> List[DefaultElement]:

def get_sources(self) -> List[ConfigSource]:
return self.repository.get_sources()


def get_overrides_dirname(
input_list: List[ParsedOverrideWithLine],
exclude_keys: List[str] = [],
item_sep: str = ",",
kv_sep: str = "=",
) -> str:
lines = []
for x in input_list:
if x.override.key not in exclude_keys:
lines.append(x.input_line)

lines.sort()
ret = re.sub(pattern="[=]", repl=kv_sep, string=item_sep.join(lines))
return ret
Loading

0 comments on commit ff6eb43

Please sign in to comment.