From 5700988b9ed9f13c49c767836f84325d6fa27c0f Mon Sep 17 00:00:00 2001 From: Omry Yadan Date: Thu, 3 Dec 2020 14:02:33 -0800 Subject: [PATCH] lint --- hydra/_internal/new_defaults_list.py | 26 +++++++++++++---------- hydra/core/new_default_element.py | 6 ++++-- hydra/plugins/config_source.py | 6 +++--- tests/defaults_list/test_defaults_list.py | 10 ++++----- 4 files changed, 27 insertions(+), 21 deletions(-) diff --git a/hydra/_internal/new_defaults_list.py b/hydra/_internal/new_defaults_list.py index d41f2c10281..d49a7d9cb20 100644 --- a/hydra/_internal/new_defaults_list.py +++ b/hydra/_internal/new_defaults_list.py @@ -31,13 +31,13 @@ class Deletion: @dataclass class Overrides: - override_choices: Dict[str, str] + override_choices: Dict[str, Optional[str]] override_used: Dict[str, bool] append_group_defaults: List[GroupDefault] config_overrides: List[Override] - known_choices: Dict[str, str] + known_choices: Dict[str, Optional[str]] deletions: Dict[str, Deletion] @@ -143,8 +143,9 @@ def is_deleted(self, default: InputDefault) -> bool: return True else: return deletion.name == default.get_name() + return False - def delete(self, default: InputDefault) -> bool: + def delete(self, default: InputDefault) -> None: assert isinstance(default, GroupDefault) default.deleted = True @@ -202,6 +203,7 @@ def _expand_virtual_root( root.children.append(gd) for d in reversed(root.children): + assert isinstance(d, InputDefault) new_root = DefaultsTreeNode(node=d, parent=root) d.parent_base_dir = "" d.parent_package = "" @@ -418,12 +420,12 @@ def _create_defaults_tree_impl( # processed deferred interpolations known_choices = _create_interpolation_map(overrides, defaults_list, self_added) - for idx, d in enumerate(children): - if isinstance(d, InputDefault) and d.is_interpolation(): - d.resolve_interpolation(known_choices) - new_root = DefaultsTreeNode(node=d, parent=root) - d.parent_base_dir = parent.get_group_path() - d.parent_package = parent.get_final_package() + for idx, dd in enumerate(children): + if isinstance(dd, InputDefault) and dd.is_interpolation(): + dd.resolve_interpolation(known_choices) + new_root = DefaultsTreeNode(node=dd, parent=root) + dd.parent_base_dir = parent.get_group_path() + dd.parent_package = parent.get_final_package() subtree = _create_defaults_tree_impl( repo=repo, root=new_root, @@ -469,7 +471,7 @@ def _create_result_default( def _dfs_walk( tree: DefaultsTreeNode, - operator: Callable[[DefaultsTreeNode, InputDefault], None], + operator: Callable[[Optional[DefaultsTreeNode], InputDefault], None], ) -> None: if tree.children is None or len(tree.children) == 0: operator(tree.parent, tree.node) @@ -489,7 +491,9 @@ class Collector: def __init__(self) -> None: self.output: List[ResultDefault] = [] - def __call__(self, tree_node: DefaultsTreeNode, node: InputDefault) -> None: + def __call__( + self, tree_node: Optional[DefaultsTreeNode], node: InputDefault + ) -> None: if node.is_deleted(): return diff --git a/hydra/core/new_default_element.py b/hydra/core/new_default_element.py index dfd36089f0d..1cff1cc0890 100644 --- a/hydra/core/new_default_element.py +++ b/hydra/core/new_default_element.py @@ -110,7 +110,7 @@ def _get_final_package( self, parent_package: Optional[str], package: Optional[str], - name: str, + name: Optional[str], ) -> str: assert parent_package is not None if package is None: @@ -287,7 +287,9 @@ def get_config_path(self) -> str: def get_final_package(self) -> str: return self._get_final_package( - self.parent_package, self.get_package(), self.get_name() + self.parent_package, + self.get_package(), + self.get_name(), ) def _relative_group_path(self) -> str: diff --git a/hydra/plugins/config_source.py b/hydra/plugins/config_source.py index b6e8b05d645..2919b15014f 100644 --- a/hydra/plugins/config_source.py +++ b/hydra/plugins/config_source.py @@ -7,7 +7,7 @@ from abc import abstractmethod from dataclasses import dataclass -from typing import List, Optional, Dict, Tuple, MutableSequence, Union +from typing import List, Optional, Dict, Tuple, MutableSequence from hydra.core import DefaultElement from hydra.core.new_default_element import InputDefault, ConfigDefault, GroupDefault @@ -415,7 +415,7 @@ def _create_new_defaults_list( return res @staticmethod - def _extract_raw_defaults_list(cfg: Container) -> Union[ListConfig, DictConfig]: + def _extract_raw_defaults_list(cfg: Container) -> ListConfig: empty = OmegaConf.create([]) if not OmegaConf.is_dict(cfg): return empty @@ -423,5 +423,5 @@ def _extract_raw_defaults_list(cfg: Container) -> Union[ListConfig, DictConfig]: with read_write(cfg): with open_dict(cfg): defaults = cfg.pop("defaults", empty) - + assert isinstance(defaults, ListConfig) return defaults diff --git a/tests/defaults_list/test_defaults_list.py b/tests/defaults_list/test_defaults_list.py index 200cdc43c66..140c2e143d8 100644 --- a/tests/defaults_list/test_defaults_list.py +++ b/tests/defaults_list/test_defaults_list.py @@ -1186,7 +1186,7 @@ def test_with_hydra_config( ) def test_experiment_use_case( config_name: str, overrides: List[str], expected: List[ResultDefault] -): +) -> None: _test_defaults_list_impl( config_name=config_name, overrides=overrides, @@ -1230,7 +1230,7 @@ def test_experiment_use_case( ) def test_as_as_primary( config_name: str, overrides: List[str], expected: List[ResultDefault] -): +) -> None: _test_defaults_list_impl( config_name=config_name, overrides=overrides, @@ -1300,7 +1300,7 @@ def test_as_as_primary( ) def test_placeholder( config_name: str, overrides: List[str], expected: List[ResultDefault] -): +) -> None: _test_defaults_list_impl( config_name=config_name, overrides=overrides, @@ -1373,7 +1373,7 @@ def test_placeholder( ) def test_interpolation_simple( config_name: str, overrides: List[str], expected: List[ResultDefault] -): +) -> None: _test_defaults_list_impl( config_name=config_name, overrides=overrides, @@ -1398,7 +1398,7 @@ def test_interpolation_simple( ) def test_deletion( config_name: str, overrides: List[str], expected: List[ResultDefault] -): +) -> None: _test_defaults_list_impl( config_name=config_name, overrides=overrides,