Skip to content

Commit

Permalink
Override delete support (~group, ~group=choice)
Browse files Browse the repository at this point in the history
  • Loading branch information
omry committed Dec 3, 2020
1 parent c5bc56a commit 3f0a4ab
Show file tree
Hide file tree
Showing 6 changed files with 340 additions and 29 deletions.
125 changes: 104 additions & 21 deletions hydra/_internal/new_defaults_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import copy
import warnings
from dataclasses import dataclass
from dataclasses import dataclass, field
from textwrap import dedent
from typing import Dict, List, Optional, Union
from typing import Callable, Dict, List, Optional, Union

from omegaconf import DictConfig, OmegaConf

Expand All @@ -23,6 +23,12 @@
from hydra.errors import ConfigCompositionException


@dataclass
class Deletion:
name: Optional[str]
used: bool = field(default=False, compare=False)


@dataclass
class Overrides:
override_choices: Dict[str, str]
Expand All @@ -33,11 +39,14 @@ class Overrides:

known_choices: Dict[str, str]

deletions: Dict[str, Deletion]

def __init__(self, repo: IConfigRepository, overrides_list: List[Override]) -> None:
self.override_choices = {}
self.override_used = {}
self.append_group_defaults = []
self.config_overrides = []
self.deletions = {}

self.known_choices = {}

Expand All @@ -47,11 +56,21 @@ def __init__(self, repo: IConfigRepository, overrides_list: List[Override]) -> N
if not is_group:
self.config_overrides.append(override)
else:
if not isinstance(value, str):
if override.is_delete():
key = override.get_key_element()[1:]
value = override.value()
if value is not None and not isinstance(value, str):
raise ValueError(
f"Config group override deletion value must be a string : {override}"
)

self.deletions[key] = Deletion(name=value)

elif not isinstance(value, str):
raise ValueError(
f"Config group override must be a string : {override}"
)
if override.is_add():
elif override.is_add():
self.append_group_defaults.append(
GroupDefault(
group=override.key_or_group,
Expand Down Expand Up @@ -95,6 +114,13 @@ def ensure_overrides_used(self) -> None:
)
raise ConfigCompositionException(msg)

def ensure_deletions_used(self) -> None:
for key, deletion in self.deletions.items():
if not deletion.used:
desc = f"{key}={deletion.name}" if deletion.name is not None else key
msg = f"Could not delete '{desc}'. No match in the defaults list"
raise ConfigCompositionException(msg)

def set_known_choice(self, default: InputDefault) -> None:
if isinstance(default, GroupDefault):
key = default.get_override_key()
Expand All @@ -107,14 +133,32 @@ def set_known_choice(self, default: InputDefault) -> None:
f"Internal error, value of {key} is being changed from {prev} to {default.get_name()}"
)

def is_deleted(self, default: InputDefault) -> bool:
if not isinstance(default, GroupDefault):
return False
key = default.get_override_key()
if key in self.deletions:
deletion = self.deletions[key]
if deletion.name is None:
return True
else:
return deletion.name == default.get_name()

def delete(self, default: InputDefault) -> bool:
assert isinstance(default, GroupDefault)
default.deleted = True

key = default.get_override_key()
self.deletions[key].used = True


@dataclass
class DefaultsList:
defaults: List[ResultDefault]
config_overrides: List[Override]


def _validate_self(containing_node: InputDefault, defaults: List[InputDefault]) -> None:
def _validate_self(containing_node: InputDefault, defaults: List[InputDefault]) -> bool:
# check that self is present only once
has_self = False
for d in defaults:
Expand Down Expand Up @@ -162,7 +206,7 @@ def _expand_virtual_root(
d.parent_base_dir = ""
d.parent_package = ""

subtree = _create_defaults_tree(
subtree = _create_defaults_tree_impl(
repo=repo,
root=new_root,
is_primary_config=False,
Expand Down Expand Up @@ -237,6 +281,26 @@ def _create_defaults_tree(
skip_missing: bool,
interpolated_subtree: bool,
overrides: Overrides,
) -> DefaultsTreeNode:
ret = _create_defaults_tree_impl(
repo=repo,
root=root,
is_primary_config=is_primary_config,
skip_missing=skip_missing,
interpolated_subtree=interpolated_subtree,
overrides=overrides,
)

return ret


def _create_defaults_tree_impl(
repo: IConfigRepository,
root: DefaultsTreeNode,
is_primary_config: bool,
skip_missing: bool,
interpolated_subtree: bool,
overrides: Overrides,
) -> DefaultsTreeNode:
parent = root.node
children: List[Union[InputDefault, DefaultsTreeNode]] = []
Expand All @@ -261,6 +325,11 @@ def _create_defaults_tree(
repo=repo, node=parent, is_primary_config=is_primary_config
)

if overrides.is_deleted(parent):
overrides.delete(parent)
# parent.deleted = True
return root

overrides.set_known_choice(parent)

if parent.get_name() is None:
Expand Down Expand Up @@ -333,7 +402,7 @@ def _create_defaults_tree(
children.append(d)
continue

subtree = _create_defaults_tree(
subtree = _create_defaults_tree_impl(
repo=repo,
root=new_root,
is_primary_config=False,
Expand All @@ -355,7 +424,7 @@ def _create_defaults_tree(
new_root = DefaultsTreeNode(node=d, parent=root)
d.parent_base_dir = parent.get_group_path()
d.parent_package = parent.get_final_package()
subtree = _create_defaults_tree(
subtree = _create_defaults_tree_impl(
repo=repo,
root=new_root,
is_primary_config=False,
Expand Down Expand Up @@ -398,25 +467,39 @@ def _create_result_default(
return res


def _tree_to_list(
def _dfs_walk(
tree: DefaultsTreeNode,
output: List[ResultDefault],
operator: Callable[[DefaultsTreeNode, InputDefault], None],
) -> None:
node = tree.node

if tree.children is None or len(tree.children) == 0:
rd = _create_result_default(tree=tree.parent, node=node)
if rd is not None:
output.append(rd)
operator(tree.parent, tree.node)
else:
for child in tree.children:
if isinstance(child, InputDefault):
rd = _create_result_default(tree=tree, node=child)
if rd is not None:
output.append(rd)
operator(tree, child)
else:
assert isinstance(child, DefaultsTreeNode)
_tree_to_list(tree=child, output=output)
_dfs_walk(tree=child, operator=operator)


def _tree_to_list(
tree: DefaultsTreeNode,
) -> List[ResultDefault]:
class Collector:
def __init__(self) -> None:
self.output: List[ResultDefault] = []

def __call__(self, tree_node: DefaultsTreeNode, node: InputDefault) -> None:
if node.is_deleted():
return

rd = _create_result_default(tree=tree_node, node=node)
if rd is not None:
self.output.append(rd)

visitor = Collector()
_dfs_walk(tree, visitor)
return visitor.output


def _create_root(config_name: str, with_hydra: bool) -> DefaultsTreeNode:
Expand Down Expand Up @@ -452,8 +535,7 @@ def _create_defaults_list(
skip_missing=skip_missing,
)

output: List[ResultDefault] = []
_tree_to_list(tree=defaults_tree, output=output)
output = _tree_to_list(tree=defaults_tree)
# TODO: fail if duplicate items exists
return output

Expand Down Expand Up @@ -482,6 +564,7 @@ def create_defaults_list(
skip_missing=skip_missing,
)
overrides.ensure_overrides_used()
overrides.ensure_deletions_used()
ret = DefaultsList(defaults=defaults, config_overrides=overrides.config_overrides)
return ret

Expand Down
18 changes: 13 additions & 5 deletions hydra/core/new_default_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ def _get_parent_package(self) -> Optional[str]:
def is_virtual(self) -> bool:
return False

def is_deleted(self) -> bool:
if "deleted" in self.__dict__:
return bool(self.__dict__["deleted"])
else:
return False

def set_package_header(self, package_header: str) -> None:
assert self.__dict__["package_header"] is None
# package header is always interpreted as absolute.
Expand Down Expand Up @@ -141,7 +147,11 @@ def __repr__(self) -> str:
for attr in attr_names:
value = getattr(self, attr)
if value is not None:
attrs.append(f'{attr}="{value}"')
if isinstance(value, str):
svalue = f'"{value}"'
else:
svalue = value
attrs.append(f"{attr}={svalue}")

flags = []
flag_names = self._get_flags()
Expand Down Expand Up @@ -185,9 +195,6 @@ def get_group_path(self) -> str:
def get_config_path(self) -> str:
return "<root>"

def get_default_package(self) -> str:
return self.get_group_path().replace("/", ".")

def get_final_package(self) -> str:
raise NotImplementedError()

Expand Down Expand Up @@ -319,6 +326,7 @@ class GroupDefault(InputDefault):
package: Optional[str] = None

override: bool = False
deleted: Optional[bool] = None

config_name_overridden: bool = field(default=False, compare=False, repr=False)

Expand Down Expand Up @@ -368,7 +376,7 @@ def _relative_group_path(self) -> str:
return self.group

def _get_attributes(self) -> List[str]:
return ["group", "name", "package"]
return ["group", "name", "package", "deleted"]

def _get_flags(self) -> List[str]:
return ["optional", "override"]
Expand Down
2 changes: 2 additions & 0 deletions tests/defaults_list/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def _test_defaults_tree_impl(
skip_missing=skip_missing,
)
overrides.ensure_overrides_used()
overrides.ensure_deletions_used()
assert result == expected
else:
with expected:
Expand All @@ -53,3 +54,4 @@ def _test_defaults_tree_impl(
skip_missing=skip_missing,
)
overrides.ensure_overrides_used()
overrides.ensure_deletions_used()
36 changes: 34 additions & 2 deletions tests/defaults_list/test_defaults_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,18 +66,25 @@
# - (Y) Test interpolated config with a defaults list
# - (Y) Error if interpolated config defaults list has an overrides
# - (Y) Support and deprecate legacy defaults list interpolation style
# TODO: Consider delete support
# TODO: delete support from overrides
# - (Y) override to null from defaults list
# - (Y) Support delete by group
# - (Y) Support delete by group=value
# - (Y) Test deletion with a specific package
# - (Y) Deletion test with final defaults list
# TODO: Consider package rename support


# TODO: Error handling:
# - Error handling for entries that failed to override anything
# - Error if delete override did not delete anything
# - Duplicate _self_ error
# - test handling missing configs mentioned in defaults list (with and without optional)
# - Ambiguous overrides should provide valid override keys for group
# - Test deprecation message when attempting to override hydra configs without override: true
# - Should duplicate entries in results list be an error? (same override key)


# TODO: Integrate with Hydra
# - replace old defaults list computation
# - enable --info=defaults output
Expand Down Expand Up @@ -156,7 +163,7 @@ def _test_defaults_list_impl(
parser = OverridesParser.create()
repo = create_repo()
overrides_list = parser.parse_overrides(overrides=overrides)
if isinstance(expected, list):
if isinstance(expected, list) or expected is None:
result = create_defaults_list(
repo=repo,
config_name=config_name,
Expand Down Expand Up @@ -1372,3 +1379,28 @@ def test_interpolation_simple(
overrides=overrides,
expected=expected,
)


@mark.parametrize( # type: ignore
"config_name,overrides,expected",
[
param(
"include_nested_group",
["~group1"],
[
ResultDefault(
config_path="include_nested_group", package="", is_self=True
),
],
id="delete:include_nested_group:group1",
),
],
)
def test_deletion(
config_name: str, overrides: List[str], expected: List[ResultDefault]
):
_test_defaults_list_impl(
config_name=config_name,
overrides=overrides,
expected=expected,
)
Loading

0 comments on commit 3f0a4ab

Please sign in to comment.