From e93878ee9a4d1eb0760e5f1f8973494d4c37763a Mon Sep 17 00:00:00 2001 From: Omry Yadan Date: Mon, 17 Feb 2020 19:22:55 -0800 Subject: [PATCH] wip --- hydra/_internal/config_loader_impl.py | 2 +- .../core_plugins/structured_config_source.py | 8 +-- hydra/conf/__init__.py | 3 +- hydra/core/config_store.py | 61 ++++++++++++++-- tests/test_config_loader.py | 12 ++-- tests/test_structured_config_loading.py | 72 +++++++++++++++++++ .../structured_config/10_config_store.md | 4 +- 7 files changed, 143 insertions(+), 19 deletions(-) create mode 100644 tests/test_structured_config_loading.py diff --git a/hydra/_internal/config_loader_impl.py b/hydra/_internal/config_loader_impl.py index 569ca9150ae..9d933fc9eb1 100644 --- a/hydra/_internal/config_loader_impl.py +++ b/hydra/_internal/config_loader_impl.py @@ -250,7 +250,7 @@ def _load_config_impl( schema = ConfigStore.instance().load( config_path=ConfigSource._normalize_file_name(filename=input_file) ) - merged = OmegaConf.merge(schema, ret.config) + merged = OmegaConf.merge(schema.node, ret.config) assert isinstance(merged, DictConfig) return merged except ConfigLoadError: diff --git a/hydra/_internal/core_plugins/structured_config_source.py b/hydra/_internal/core_plugins/structured_config_source.py index aa8bdd77b68..2a427656f6e 100644 --- a/hydra/_internal/core_plugins/structured_config_source.py +++ b/hydra/_internal/core_plugins/structured_config_source.py @@ -3,7 +3,7 @@ import warnings from typing import List, Optional -from hydra.core.config_store import ConfigStore +from hydra.core.config_store import ConfigStore, ConfigNode from hydra.core.object_type import ObjectType from hydra.plugins.config_source import ConfigResult, ConfigSource @@ -30,10 +30,10 @@ def scheme() -> str: def load_config(self, config_path: str) -> ConfigResult: full_path = self._normalize_file_name(config_path) + ret = self.store.load(config_path=full_path) + provider = ret.provider if ret.provider is not None else self.provider return ConfigResult( - config=self.store.load(config_path=full_path), - path=f"{self.scheme()}://{self.path}", - provider=self.provider, + config=ret.node, path=f"{self.scheme()}://{self.path}", provider=provider ) def is_group(self, config_path: str) -> bool: diff --git a/hydra/conf/__init__.py b/hydra/conf/__init__.py index d5d37c6e148..16c59e9c976 100644 --- a/hydra/conf/__init__.py +++ b/hydra/conf/__init__.py @@ -106,7 +106,7 @@ class RuntimeConf: @dataclass -class HydraConf(Dict[str, Any]): +class HydraConf: # Normal run output configuration run: RunDir = RunDir() # Multi-run output configuration @@ -158,4 +158,5 @@ class HydraConf(Dict[str, Any]): # Hydra config "hydra": HydraConf, }, + provider="hydra", ) diff --git a/hydra/core/config_store.py b/hydra/core/config_store.py index 8cdb099e07d..7b2af8a6446 100644 --- a/hydra/core/config_store.py +++ b/hydra/core/config_store.py @@ -1,6 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved from typing import Any, Dict, List, Optional - +from dataclasses import dataclass from omegaconf import DictConfig, OmegaConf from hydra.core.object_type import ObjectType @@ -8,6 +8,50 @@ from hydra.plugins.config_source import ConfigLoadError +class ConfigStoreWithProvider: + def __init__(self, provider: str) -> None: + self.provider = provider + + def __enter__(self) -> "ConfigStoreWithProvider": + return self + + def store( + self, + name: str, + node: Any, + group: Optional[str] = None, + path: Optional[str] = None, + ) -> None: + ConfigStore.instance().store( + group=group, name=name, node=node, path=path, provider=self.provider + ) + + def __exit__(self, exc_type: Any, exc_value: Any, exc_traceback: Any) -> Any: + ... + + +class ConfigNode: + name: str + node: Any + group: Optional[str] + path: Optional[str] + provider: Optional[str] + + def __init__( + self, + name: str, + node: Any, + group: Optional[str], + path: Optional[str], + provider: Optional[str], + ): + self.name = name + self.node = node + self.group = group + self.path = path + self.provider = provider + + class ConfigStore(metaclass=Singleton): @staticmethod def instance(*args: Any, **kwargs: Any) -> "ConfigStore": @@ -24,13 +68,15 @@ def store( node: Any, group: Optional[str] = None, path: Optional[str] = None, + provider: Optional[str] = None, ) -> None: """ Stores a config node into the repository :param name: config name :param node: config node, can be DictConfig, ListConfig, Structured configs and even dict and list :param group: config group, subgroup separator is '/', for example hydra/launcher - :param path: Config node parent hierarchy. child separator is '.', for example foo.bar.bazz + :param path: Config node parent hierarchy. child separator is '.', for example foo.bar.baz + :param provider: the name of the module/app providing this config. Helps debugging. """ cur = self.repo if group is not None: @@ -47,16 +93,19 @@ def store( if not name.endswith(".yaml"): name = f"{name}.yaml" - cur[name] = cfg + assert isinstance(cur, dict) + cur[name] = ConfigNode( + name=name, node=cfg, group=group, path=path, provider=provider + ) - def load(self, config_path: str) -> DictConfig: + def load(self, config_path: str) -> ConfigNode: idx = config_path.rfind("/") if idx == -1: ret = self._open(config_path) if ret is None: raise ConfigLoadError(f"Structured config not found {config_path}") - assert isinstance(ret, DictConfig) + assert isinstance(ret, ConfigNode) return ret else: path = config_path[0:idx] @@ -71,7 +120,7 @@ def load(self, config_path: str) -> DictConfig: ) ret = d[name] - assert isinstance(ret, DictConfig) + assert isinstance(ret, ConfigNode) return ret def get_type(self, path: str) -> ObjectType: diff --git a/tests/test_config_loader.py b/tests/test_config_loader.py index b048de4410b..6d8bba1736d 100644 --- a/tests/test_config_loader.py +++ b/tests/test_config_loader.py @@ -139,7 +139,7 @@ def test_load_history(self, path: str) -> None: config_name="missing-optional-default.yaml", overrides=[], strict=False ) assert config_loader.get_load_history() == [ - ("hydra_config", "structured://", "schema"), + ("hydra_config", "structured://", "hydra"), ("hydra/hydra_logging/default", "pkg://hydra.conf", "hydra"), ("hydra/job_logging/default", "pkg://hydra.conf", "hydra"), ("hydra/launcher/basic", "pkg://hydra.conf", "hydra"), @@ -162,7 +162,7 @@ def test_load_history_with_basic_launcher(self, path: str) -> None: ) assert config_loader.get_load_history() == [ - ("hydra_config", "structured://", "schema"), + ("hydra_config", "structured://", "hydra"), ("hydra/hydra_logging/default", "pkg://hydra.conf", "hydra"), ("hydra/job_logging/default", "pkg://hydra.conf", "hydra"), ("hydra/launcher/basic", "pkg://hydra.conf", "hydra"), @@ -310,7 +310,7 @@ def test_default_removal(config_file: str, overrides: List[str]) -> None: config_name=config_file, overrides=overrides, strict=False ) assert config_loader.get_load_history() == [ - ("hydra_config", "structured://", "schema"), + ("hydra_config", "structured://", "hydra"), ("hydra/hydra_logging/default", "pkg://hydra.conf", "hydra"), ("hydra/job_logging/default", "pkg://hydra.conf", "hydra"), ("hydra/sweeper/basic", "pkg://hydra.conf", "hydra"), @@ -368,7 +368,7 @@ def test_override_hydra_config_group_from_config_file() -> None: config_name="overriding_logging_default.yaml", overrides=[], strict=False ) assert config_loader.get_load_history() == [ - ("hydra_config", "structured://", "schema"), + ("hydra_config", "structured://", "hydra"), ("hydra/hydra_logging/hydra_debug", "pkg://hydra.conf", "hydra"), ("hydra/job_logging/disabled", "pkg://hydra.conf", "hydra"), ("hydra/sweeper/basic", "pkg://hydra.conf", "hydra"), @@ -413,7 +413,7 @@ def test_non_config_group_default() -> None: config_name="non_config_group_default.yaml", overrides=[], strict=False ) assert config_loader.get_load_history() == [ - ("hydra_config", "structured://", "schema"), + ("hydra_config", "structured://", "hydra"), ("hydra/hydra_logging/default", "pkg://hydra.conf", "hydra"), ("hydra/job_logging/default", "pkg://hydra.conf", "hydra"), ("hydra/launcher/basic", "pkg://hydra.conf", "hydra"), @@ -438,7 +438,7 @@ def test_mixed_composition_order() -> None: config_name="mixed_compose.yaml", overrides=[], strict=False ) assert config_loader.get_load_history() == [ - ("hydra_config", "structured://", "schema"), + ("hydra_config", "structured://", "hydra"), ("hydra/hydra_logging/default", "pkg://hydra.conf", "hydra"), ("hydra/job_logging/default", "pkg://hydra.conf", "hydra"), ("hydra/launcher/basic", "pkg://hydra.conf", "hydra"), diff --git a/tests/test_structured_config_loading.py b/tests/test_structured_config_loading.py new file mode 100644 index 00000000000..803698f5b58 --- /dev/null +++ b/tests/test_structured_config_loading.py @@ -0,0 +1,72 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from dataclasses import dataclass +from typing import Any, List + +import pkg_resources +import pytest +from omegaconf import MISSING, ListConfig, OmegaConf, ValidationError + +from hydra._internal.config_loader_impl import ConfigLoaderImpl +from hydra._internal.utils import create_config_search_path +from hydra.core.config_store import ConfigStore, ConfigStoreWithProvider +from hydra.errors import MissingConfigException +from hydra.test_utils.test_utils import ( # noqa: F401 + chdir_hydra_root, + restore_singletons, +) +from typing import Any, Optional + +chdir_hydra_root() + + +@dataclass +class MySQLConfig: + driver: str = MISSING + host: str = MISSING + port: int = MISSING + user: str = MISSING + password: str = MISSING + + +hydra_load_list = [ + ("hydra_config", "structured://", "hydra"), + ("hydra/hydra_logging/default", "pkg://hydra.conf", "hydra"), + ("hydra/job_logging/default", "pkg://hydra.conf", "hydra"), + ("hydra/launcher/basic", "pkg://hydra.conf", "hydra"), + ("hydra/sweeper/basic", "pkg://hydra.conf", "hydra"), + ("hydra/output/default", "pkg://hydra.conf", "hydra"), + ("hydra/help/default", "pkg://hydra.conf", "hydra"), + ("hydra/hydra_help/default", "pkg://hydra.conf", "hydra"), +] + + +def test_load_as_configuration(restore_singletons) -> None: + """ + Load structured config as a configuration + """ + with ConfigStoreWithProvider("test_provider") as config_store: + config_store.store(group="db", name="mysql", node=MySQLConfig, path="db") + + config_loader = ConfigLoaderImpl(config_search_path=create_config_search_path(None)) + cfg = config_loader.load_configuration(config_name="db/mysql", overrides=[]) + del cfg["hydra"] + assert cfg == { + "db": { + "driver": MISSING, + "host": MISSING, + "port": MISSING, + "user": MISSING, + "password": MISSING, + } + } + + expected = hydra_load_list.copy() + expected.extend([("db/mysql", "structured://", "test_provider")]) + assert config_loader.get_load_history() == expected + + +@pytest.mark.parametrize( + "path", ["file://hydra/test_utils/configs", "pkg://hydra.test_utils.configs"] +) +class TestConfigLoader: + pass diff --git a/website/docs/tutorials/structured_config/10_config_store.md b/website/docs/tutorials/structured_config/10_config_store.md index 9feebf7b955..135147a7fd2 100644 --- a/website/docs/tutorials/structured_config/10_config_store.md +++ b/website/docs/tutorials/structured_config/10_config_store.md @@ -13,13 +13,15 @@ class ConfigStore(metaclass=Singleton): node: Any, group: Optional[str] = None, path: Optional[str] = None, + provider: Optional[str] = None, ) -> None: """ Stores a config node into the repository :param name: config name :param node: config node, can be DictConfig, ListConfig, Structured configs and even dict and list :param group: config group, subgroup separator is '/', for example hydra/launcher - :param path: Config node parent hierarchy. child separator is '.', for example foo.bar.bazz + :param path: Config node parent hierarchy. child separator is '.', for example foo.bar.baz + :param provider: the name of the module/app providing this config. Helps debugging. """ ... ``` \ No newline at end of file