Skip to content

Commit

Permalink
fixed errors 2
Browse files Browse the repository at this point in the history
  • Loading branch information
omry committed May 21, 2020
1 parent ff6eb43 commit c5390ef
Show file tree
Hide file tree
Showing 21 changed files with 77 additions and 37 deletions.
4 changes: 2 additions & 2 deletions hydra/_internal/config_loader_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def split_overrides(
config_group_overrides = []
config_overrides = []
for pwd in pairs:
if not self.repository.exists(pwd.override.key):
if not self.repository.group_exists(pwd.override.key):
config_overrides.append(pwd)
else:
config_group_overrides.append(pwd)
Expand All @@ -140,7 +140,7 @@ def load_configuration(

parsed_overrides = [self._parse_override(override) for override in overrides]

if config_name is not None and not self.repository.exists(config_name):
if config_name is not None and not self.repository.config_exists(config_name):
# TODO: handle schema as a special case
descs = [
f"\t{src.path} (from {src.provider})"
Expand Down
28 changes: 21 additions & 7 deletions hydra/_internal/config_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def load_config(
is_primary_config: bool,
package_override: Optional[str] = None,
) -> Optional[ConfigResult]:
source = self._find_config(config_path=config_path)
source = self._find_object_source(
config_path=config_path, object_type=ObjectType.CONFIG
)
ret = None
if source is not None:
ret = source.load_config(
Expand All @@ -52,8 +54,11 @@ def load_config(
)
return ret

def exists(self, config_path: str) -> bool:
return self._find_config(config_path) is not None
def group_exists(self, config_path: str) -> bool:
return self._find_object_source(config_path, ObjectType.GROUP) is not None

def config_exists(self, config_path: str) -> bool: # TODO: rename to config_exists?
return self._find_object_source(config_path, ObjectType.CONFIG) is not None

def get_group_options(
self, group_name: str, results_filter: Optional[ObjectType] = ObjectType.CONFIG
Expand All @@ -69,12 +74,21 @@ def get_group_options(
def get_sources(self) -> List[ConfigSource]:
return self.sources

def _find_config(self, config_path: str) -> Optional[ConfigSource]:
def _find_object_source(
self, config_path: str, object_type: Optional[ObjectType]
) -> Optional[ConfigSource]:
found_source = None
for source in self.sources:
if source.exists(config_path):
found_source = source
break
if object_type == ObjectType.CONFIG:
if source.is_config(config_path):
found_source = source
break
elif object_type == ObjectType.GROUP:
if source.is_group(config_path):
found_source = source
break
else:
raise ValueError("Unexpected object_type")
return found_source

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions hydra/plugins/config_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ def _resolve_package(

return package

@staticmethod
def _update_package_in_header(
self,
header: Dict[str, str],
normalized_config_path: str,
is_primary_config: bool,
Expand Down Expand Up @@ -177,7 +177,7 @@ def _update_package_in_header(
# Hydra 1.1: default will change to _package_ and the warning will be removed.
header["package"] = "_global_"
msg = (
f"\nMissing @package directive in {normalized_config_path}.\n"
f"\nMissing @package directive {normalized_config_path} in {self.full_path()}.\n"
f"See https://hydra.cc/next/upgrades/0.11_to_1.0/package_header"
)
warnings.warn(message=msg, category=UserWarning)
Expand Down
1 change: 0 additions & 1 deletion hydra/test_utils/configs/config.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
# @package _global_
# intentionally .yml and not .yaml
yml_file_here: true
1 change: 0 additions & 1 deletion hydra/test_utils/configs/custom_default_launcher.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
# @package _global_
defaults:
- hydra/launcher: fancy_launcher
1 change: 0 additions & 1 deletion hydra/test_utils/configs/db_conf.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
# @package _global_
defaults:
- db: mysql
1 change: 0 additions & 1 deletion hydra/test_utils/configs/defaults_not_list.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# @package _global_
defaults:
all: wrong
should_be: list
1 change: 0 additions & 1 deletion hydra/test_utils/configs/missing-default.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
# @package _global_
defaults:
- foo: file1
1 change: 0 additions & 1 deletion hydra/test_utils/configs/missing-optional-default.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# @package _global_
defaults:
- foo: missing
optional: true
1 change: 0 additions & 1 deletion hydra/test_utils/configs/mixed_compose.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# @package _global_
defaults:
- some_config
- group1: file1
Expand Down
1 change: 0 additions & 1 deletion hydra/test_utils/configs/non_config_group_default.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
# @package _global_
defaults:
- some_config
1 change: 0 additions & 1 deletion hydra/test_utils/configs/optional-default.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# @package _global_
defaults:
- group1: file1
optional: true
1 change: 0 additions & 1 deletion hydra/test_utils/configs/overriding_logging_default.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# @package _global_
defaults:
- hydra/launcher: null
- hydra/hydra_logging: hydra_debug
Expand Down
1 change: 0 additions & 1 deletion hydra/test_utils/configs/overriding_output_dir.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# @package _global_
hydra:
run:
dir: foo
1 change: 0 additions & 1 deletion hydra/test_utils/configs/overriding_run_dir.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# @package _global_
hydra:
run:
dir: cde
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
# @package _global_
defaults:
- hydra/launcher: null
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
# @package _global_
defaults:
- group1: ???
13 changes: 13 additions & 0 deletions hydra/test_utils/example_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from omegaconf import DictConfig

import hydra


@hydra.main(config_path="configs", config_name="db_conf")
def run_cli(cfg: DictConfig) -> None:
print(cfg.pretty())


if __name__ == "__main__":
run_cli()
4 changes: 1 addition & 3 deletions tests/test_config_loader.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

# TODO: bad error for:
# python examples/tutorials/basic/your_first_hydra_app/5_selecting_defaults_for_config_groups/my_app.py db=
# TODO : If not config file is specified, do not require + prefix to add items to defaults or config.
# TODO : If not configz`g file is specified, do not require + prefix to add items to defaults or config.
# TODO: Document command line:
# +/~, pacakges, defaults manipulation, the works.
# completion
Expand Down
4 changes: 2 additions & 2 deletions tests/test_config_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def test_config_repository_exists(self, restore_singletons: Any, path: str) -> N
Plugins.instance() # initializes
config_search_path = create_config_search_path(path)
repo = ConfigRepository(config_search_path=config_search_path)
assert repo.exists("dataset/imagenet.yaml")
assert not repo.exists("not_found.yaml")
assert repo.config_exists("dataset/imagenet.yaml")
assert not repo.config_exists("not_found.yaml")

@pytest.mark.parametrize( # type: ignore
"config_path,results_filter,expected",
Expand Down
43 changes: 36 additions & 7 deletions tests/test_hydra.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import re
import subprocess
import sys
from pathlib import Path
Expand Down Expand Up @@ -165,8 +166,8 @@ def test_app_with_config_file__no_overrides(
@pytest.mark.parametrize( # type: ignore
"calling_file, calling_module",
[
("tests/test_apps/app_with_cfg_groups_no_header/my_app.py", None),
(None, "tests.test_apps.app_with_cfg_groups_no_header.my_app"),
("tests/test_apps/app_with_cfg_groups_no_header/my_app.py", None,),
(None, "tests.test_apps.app_with_cfg_groups_no_header.my_app",),
],
)
def test_config_without_package_header_warnings(
Expand All @@ -187,12 +188,10 @@ def test_config_without_package_header_warnings(
"optimizer": {"type": "nesterov", "lr": 0.001}
}

msg = (
"\nMissing @package directive in optimizer/nesterov.yaml.\n"
"See https://hydra.cc/next/upgrades/0.11_to_1.0/package_header"
)
assert len(recwarn) == 1
assert recwarn.pop().message.args[0] == msg
msg = recwarn.pop().message.args[0]
assert "Missing @package directive optimizer/nesterov.yaml in " in msg
assert "See https://hydra.cc/next/upgrades/0.11_to_1.0/package_header" in msg


@pytest.mark.parametrize( # type: ignore
Expand Down Expand Up @@ -697,3 +696,33 @@ def test_hydra_env_set(tmpdir: Path) -> None:
prints="os.environ['foo']",
expected_outputs="bar",
)


@pytest.mark.parametrize( # type: ignore
"override", [pytest.param("xyz", id="db=xyz"), pytest.param("", id="db=")]
)
@pytest.mark.parametrize( # type: ignore
"calling_file, calling_module",
[
pytest.param("hydra/test_utils/example_app.py", None, id="file"),
pytest.param(None, "hydra.test_utils.example_app", id="module"),
],
)
def test_override_with_invalid_group_choice(
restore_singletons: Any,
task_runner: TTaskRunner,
calling_file: str,
calling_module: str,
override: str,
) -> None:
msg = f"""Could not load db/{override}, available options:\ndb:\n\tmysql\n\tpostgresql"""

with pytest.raises(MissingConfigException, match=re.escape(msg)):
with task_runner(
calling_file=calling_file,
calling_module=calling_module,
config_path="configs",
config_name="db_conf",
overrides=[f"db={override}"],
):
...

0 comments on commit c5390ef

Please sign in to comment.