Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
omry committed Dec 3, 2020
1 parent 3f0a4ab commit 5700988
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 21 deletions.
26 changes: 15 additions & 11 deletions hydra/_internal/new_defaults_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ class Deletion:

@dataclass
class Overrides:
override_choices: Dict[str, str]
override_choices: Dict[str, Optional[str]]
override_used: Dict[str, bool]

append_group_defaults: List[GroupDefault]
config_overrides: List[Override]

known_choices: Dict[str, str]
known_choices: Dict[str, Optional[str]]

deletions: Dict[str, Deletion]

Expand Down Expand Up @@ -143,8 +143,9 @@ def is_deleted(self, default: InputDefault) -> bool:
return True
else:
return deletion.name == default.get_name()
return False

def delete(self, default: InputDefault) -> bool:
def delete(self, default: InputDefault) -> None:
assert isinstance(default, GroupDefault)
default.deleted = True

Expand Down Expand Up @@ -202,6 +203,7 @@ def _expand_virtual_root(
root.children.append(gd)

for d in reversed(root.children):
assert isinstance(d, InputDefault)
new_root = DefaultsTreeNode(node=d, parent=root)
d.parent_base_dir = ""
d.parent_package = ""
Expand Down Expand Up @@ -418,12 +420,12 @@ def _create_defaults_tree_impl(
# processed deferred interpolations
known_choices = _create_interpolation_map(overrides, defaults_list, self_added)

for idx, d in enumerate(children):
if isinstance(d, InputDefault) and d.is_interpolation():
d.resolve_interpolation(known_choices)
new_root = DefaultsTreeNode(node=d, parent=root)
d.parent_base_dir = parent.get_group_path()
d.parent_package = parent.get_final_package()
for idx, dd in enumerate(children):
if isinstance(dd, InputDefault) and dd.is_interpolation():
dd.resolve_interpolation(known_choices)
new_root = DefaultsTreeNode(node=dd, parent=root)
dd.parent_base_dir = parent.get_group_path()
dd.parent_package = parent.get_final_package()
subtree = _create_defaults_tree_impl(
repo=repo,
root=new_root,
Expand Down Expand Up @@ -469,7 +471,7 @@ def _create_result_default(

def _dfs_walk(
tree: DefaultsTreeNode,
operator: Callable[[DefaultsTreeNode, InputDefault], None],
operator: Callable[[Optional[DefaultsTreeNode], InputDefault], None],
) -> None:
if tree.children is None or len(tree.children) == 0:
operator(tree.parent, tree.node)
Expand All @@ -489,7 +491,9 @@ class Collector:
def __init__(self) -> None:
self.output: List[ResultDefault] = []

def __call__(self, tree_node: DefaultsTreeNode, node: InputDefault) -> None:
def __call__(
self, tree_node: Optional[DefaultsTreeNode], node: InputDefault
) -> None:
if node.is_deleted():
return

Expand Down
6 changes: 4 additions & 2 deletions hydra/core/new_default_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def _get_final_package(
self,
parent_package: Optional[str],
package: Optional[str],
name: str,
name: Optional[str],
) -> str:
assert parent_package is not None
if package is None:
Expand Down Expand Up @@ -287,7 +287,9 @@ def get_config_path(self) -> str:

def get_final_package(self) -> str:
return self._get_final_package(
self.parent_package, self.get_package(), self.get_name()
self.parent_package,
self.get_package(),
self.get_name(),
)

def _relative_group_path(self) -> str:
Expand Down
6 changes: 3 additions & 3 deletions hydra/plugins/config_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from abc import abstractmethod
from dataclasses import dataclass
from typing import List, Optional, Dict, Tuple, MutableSequence, Union
from typing import List, Optional, Dict, Tuple, MutableSequence

from hydra.core import DefaultElement
from hydra.core.new_default_element import InputDefault, ConfigDefault, GroupDefault
Expand Down Expand Up @@ -415,13 +415,13 @@ def _create_new_defaults_list(
return res

@staticmethod
def _extract_raw_defaults_list(cfg: Container) -> Union[ListConfig, DictConfig]:
def _extract_raw_defaults_list(cfg: Container) -> ListConfig:
empty = OmegaConf.create([])
if not OmegaConf.is_dict(cfg):
return empty
assert isinstance(cfg, DictConfig)
with read_write(cfg):
with open_dict(cfg):
defaults = cfg.pop("defaults", empty)

assert isinstance(defaults, ListConfig)
return defaults
10 changes: 5 additions & 5 deletions tests/defaults_list/test_defaults_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -1186,7 +1186,7 @@ def test_with_hydra_config(
)
def test_experiment_use_case(
config_name: str, overrides: List[str], expected: List[ResultDefault]
):
) -> None:
_test_defaults_list_impl(
config_name=config_name,
overrides=overrides,
Expand Down Expand Up @@ -1230,7 +1230,7 @@ def test_experiment_use_case(
)
def test_as_as_primary(
config_name: str, overrides: List[str], expected: List[ResultDefault]
):
) -> None:
_test_defaults_list_impl(
config_name=config_name,
overrides=overrides,
Expand Down Expand Up @@ -1300,7 +1300,7 @@ def test_as_as_primary(
)
def test_placeholder(
config_name: str, overrides: List[str], expected: List[ResultDefault]
):
) -> None:
_test_defaults_list_impl(
config_name=config_name,
overrides=overrides,
Expand Down Expand Up @@ -1373,7 +1373,7 @@ def test_placeholder(
)
def test_interpolation_simple(
config_name: str, overrides: List[str], expected: List[ResultDefault]
):
) -> None:
_test_defaults_list_impl(
config_name=config_name,
overrides=overrides,
Expand All @@ -1398,7 +1398,7 @@ def test_interpolation_simple(
)
def test_deletion(
config_name: str, overrides: List[str], expected: List[ResultDefault]
):
) -> None:
_test_defaults_list_impl(
config_name=config_name,
overrides=overrides,
Expand Down

0 comments on commit 5700988

Please sign in to comment.