Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
omry committed Feb 18, 2020
1 parent fff45cc commit e93878e
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 19 deletions.
2 changes: 1 addition & 1 deletion hydra/_internal/config_loader_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def _load_config_impl(
schema = ConfigStore.instance().load(
config_path=ConfigSource._normalize_file_name(filename=input_file)
)
merged = OmegaConf.merge(schema, ret.config)
merged = OmegaConf.merge(schema.node, ret.config)
assert isinstance(merged, DictConfig)
return merged
except ConfigLoadError:
Expand Down
8 changes: 4 additions & 4 deletions hydra/_internal/core_plugins/structured_config_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import warnings
from typing import List, Optional

from hydra.core.config_store import ConfigStore
from hydra.core.config_store import ConfigStore, ConfigNode
from hydra.core.object_type import ObjectType
from hydra.plugins.config_source import ConfigResult, ConfigSource

Expand All @@ -30,10 +30,10 @@ def scheme() -> str:

def load_config(self, config_path: str) -> ConfigResult:
full_path = self._normalize_file_name(config_path)
ret = self.store.load(config_path=full_path)
provider = ret.provider if ret.provider is not None else self.provider
return ConfigResult(
config=self.store.load(config_path=full_path),
path=f"{self.scheme()}://{self.path}",
provider=self.provider,
config=ret.node, path=f"{self.scheme()}://{self.path}", provider=provider
)

def is_group(self, config_path: str) -> bool:
Expand Down
3 changes: 2 additions & 1 deletion hydra/conf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class RuntimeConf:


@dataclass
class HydraConf(Dict[str, Any]):
class HydraConf:
# Normal run output configuration
run: RunDir = RunDir()
# Multi-run output configuration
Expand Down Expand Up @@ -158,4 +158,5 @@ class HydraConf(Dict[str, Any]):
# Hydra config
"hydra": HydraConf,
},
provider="hydra",
)
61 changes: 55 additions & 6 deletions hydra/core/config_store.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,57 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Any, Dict, List, Optional

from dataclasses import dataclass
from omegaconf import DictConfig, OmegaConf

from hydra.core.object_type import ObjectType
from hydra.core.singleton import Singleton
from hydra.plugins.config_source import ConfigLoadError


class ConfigStoreWithProvider:
def __init__(self, provider: str) -> None:
self.provider = provider

def __enter__(self) -> "ConfigStoreWithProvider":
return self

def store(
self,
name: str,
node: Any,
group: Optional[str] = None,
path: Optional[str] = None,
) -> None:
ConfigStore.instance().store(
group=group, name=name, node=node, path=path, provider=self.provider
)

def __exit__(self, exc_type: Any, exc_value: Any, exc_traceback: Any) -> Any:
...


class ConfigNode:
name: str
node: Any
group: Optional[str]
path: Optional[str]
provider: Optional[str]

def __init__(
self,
name: str,
node: Any,
group: Optional[str],
path: Optional[str],
provider: Optional[str],
):
self.name = name
self.node = node
self.group = group
self.path = path
self.provider = provider


class ConfigStore(metaclass=Singleton):
@staticmethod
def instance(*args: Any, **kwargs: Any) -> "ConfigStore":
Expand All @@ -24,13 +68,15 @@ def store(
node: Any,
group: Optional[str] = None,
path: Optional[str] = None,
provider: Optional[str] = None,
) -> None:
"""
Stores a config node into the repository
:param name: config name
:param node: config node, can be DictConfig, ListConfig, Structured configs and even dict and list
:param group: config group, subgroup separator is '/', for example hydra/launcher
:param path: Config node parent hierarchy. child separator is '.', for example foo.bar.bazz
:param path: Config node parent hierarchy. child separator is '.', for example foo.bar.baz
:param provider: the name of the module/app providing this config. Helps debugging.
"""
cur = self.repo
if group is not None:
Expand All @@ -47,16 +93,19 @@ def store(

if not name.endswith(".yaml"):
name = f"{name}.yaml"
cur[name] = cfg
assert isinstance(cur, dict)
cur[name] = ConfigNode(
name=name, node=cfg, group=group, path=path, provider=provider
)

def load(self, config_path: str) -> DictConfig:
def load(self, config_path: str) -> ConfigNode:
idx = config_path.rfind("/")
if idx == -1:
ret = self._open(config_path)
if ret is None:
raise ConfigLoadError(f"Structured config not found {config_path}")

assert isinstance(ret, DictConfig)
assert isinstance(ret, ConfigNode)
return ret
else:
path = config_path[0:idx]
Expand All @@ -71,7 +120,7 @@ def load(self, config_path: str) -> DictConfig:
)

ret = d[name]
assert isinstance(ret, DictConfig)
assert isinstance(ret, ConfigNode)
return ret

def get_type(self, path: str) -> ObjectType:
Expand Down
12 changes: 6 additions & 6 deletions tests/test_config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def test_load_history(self, path: str) -> None:
config_name="missing-optional-default.yaml", overrides=[], strict=False
)
assert config_loader.get_load_history() == [
("hydra_config", "structured://", "schema"),
("hydra_config", "structured://", "hydra"),
("hydra/hydra_logging/default", "pkg://hydra.conf", "hydra"),
("hydra/job_logging/default", "pkg://hydra.conf", "hydra"),
("hydra/launcher/basic", "pkg://hydra.conf", "hydra"),
Expand All @@ -162,7 +162,7 @@ def test_load_history_with_basic_launcher(self, path: str) -> None:
)

assert config_loader.get_load_history() == [
("hydra_config", "structured://", "schema"),
("hydra_config", "structured://", "hydra"),
("hydra/hydra_logging/default", "pkg://hydra.conf", "hydra"),
("hydra/job_logging/default", "pkg://hydra.conf", "hydra"),
("hydra/launcher/basic", "pkg://hydra.conf", "hydra"),
Expand Down Expand Up @@ -310,7 +310,7 @@ def test_default_removal(config_file: str, overrides: List[str]) -> None:
config_name=config_file, overrides=overrides, strict=False
)
assert config_loader.get_load_history() == [
("hydra_config", "structured://", "schema"),
("hydra_config", "structured://", "hydra"),
("hydra/hydra_logging/default", "pkg://hydra.conf", "hydra"),
("hydra/job_logging/default", "pkg://hydra.conf", "hydra"),
("hydra/sweeper/basic", "pkg://hydra.conf", "hydra"),
Expand Down Expand Up @@ -368,7 +368,7 @@ def test_override_hydra_config_group_from_config_file() -> None:
config_name="overriding_logging_default.yaml", overrides=[], strict=False
)
assert config_loader.get_load_history() == [
("hydra_config", "structured://", "schema"),
("hydra_config", "structured://", "hydra"),
("hydra/hydra_logging/hydra_debug", "pkg://hydra.conf", "hydra"),
("hydra/job_logging/disabled", "pkg://hydra.conf", "hydra"),
("hydra/sweeper/basic", "pkg://hydra.conf", "hydra"),
Expand Down Expand Up @@ -413,7 +413,7 @@ def test_non_config_group_default() -> None:
config_name="non_config_group_default.yaml", overrides=[], strict=False
)
assert config_loader.get_load_history() == [
("hydra_config", "structured://", "schema"),
("hydra_config", "structured://", "hydra"),
("hydra/hydra_logging/default", "pkg://hydra.conf", "hydra"),
("hydra/job_logging/default", "pkg://hydra.conf", "hydra"),
("hydra/launcher/basic", "pkg://hydra.conf", "hydra"),
Expand All @@ -438,7 +438,7 @@ def test_mixed_composition_order() -> None:
config_name="mixed_compose.yaml", overrides=[], strict=False
)
assert config_loader.get_load_history() == [
("hydra_config", "structured://", "schema"),
("hydra_config", "structured://", "hydra"),
("hydra/hydra_logging/default", "pkg://hydra.conf", "hydra"),
("hydra/job_logging/default", "pkg://hydra.conf", "hydra"),
("hydra/launcher/basic", "pkg://hydra.conf", "hydra"),
Expand Down
72 changes: 72 additions & 0 deletions tests/test_structured_config_loading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from dataclasses import dataclass
from typing import Any, List

import pkg_resources
import pytest
from omegaconf import MISSING, ListConfig, OmegaConf, ValidationError

from hydra._internal.config_loader_impl import ConfigLoaderImpl
from hydra._internal.utils import create_config_search_path
from hydra.core.config_store import ConfigStore, ConfigStoreWithProvider
from hydra.errors import MissingConfigException
from hydra.test_utils.test_utils import ( # noqa: F401
chdir_hydra_root,
restore_singletons,
)
from typing import Any, Optional

chdir_hydra_root()


@dataclass
class MySQLConfig:
driver: str = MISSING
host: str = MISSING
port: int = MISSING
user: str = MISSING
password: str = MISSING


hydra_load_list = [
("hydra_config", "structured://", "hydra"),
("hydra/hydra_logging/default", "pkg://hydra.conf", "hydra"),
("hydra/job_logging/default", "pkg://hydra.conf", "hydra"),
("hydra/launcher/basic", "pkg://hydra.conf", "hydra"),
("hydra/sweeper/basic", "pkg://hydra.conf", "hydra"),
("hydra/output/default", "pkg://hydra.conf", "hydra"),
("hydra/help/default", "pkg://hydra.conf", "hydra"),
("hydra/hydra_help/default", "pkg://hydra.conf", "hydra"),
]


def test_load_as_configuration(restore_singletons) -> None:
"""
Load structured config as a configuration
"""
with ConfigStoreWithProvider("test_provider") as config_store:
config_store.store(group="db", name="mysql", node=MySQLConfig, path="db")

config_loader = ConfigLoaderImpl(config_search_path=create_config_search_path(None))
cfg = config_loader.load_configuration(config_name="db/mysql", overrides=[])
del cfg["hydra"]
assert cfg == {
"db": {
"driver": MISSING,
"host": MISSING,
"port": MISSING,
"user": MISSING,
"password": MISSING,
}
}

expected = hydra_load_list.copy()
expected.extend([("db/mysql", "structured://", "test_provider")])
assert config_loader.get_load_history() == expected


@pytest.mark.parametrize(
"path", ["file://hydra/test_utils/configs", "pkg://hydra.test_utils.configs"]
)
class TestConfigLoader:
pass
4 changes: 3 additions & 1 deletion website/docs/tutorials/structured_config/10_config_store.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@ class ConfigStore(metaclass=Singleton):
node: Any,
group: Optional[str] = None,
path: Optional[str] = None,
provider: Optional[str] = None,
) -> None:
"""
Stores a config node into the repository
:param name: config name
:param node: config node, can be DictConfig, ListConfig, Structured configs and even dict and list
:param group: config group, subgroup separator is '/', for example hydra/launcher
:param path: Config node parent hierarchy. child separator is '.', for example foo.bar.bazz
:param path: Config node parent hierarchy. child separator is '.', for example foo.bar.baz
:param provider: the name of the module/app providing this config. Helps debugging.
"""
...
```

0 comments on commit e93878e

Please sign in to comment.