Skip to content

Commit

Permalink
Support for configuring the config searchpath
Browse files Browse the repository at this point in the history
This is only allowed from the primary config.
  • Loading branch information
omry committed Mar 4, 2021
1 parent 9371a29 commit 887c38f
Show file tree
Hide file tree
Showing 7 changed files with 236 additions and 12 deletions.
74 changes: 71 additions & 3 deletions hydra/_internal/config_loader_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from collections import defaultdict
from dataclasses import dataclass
from textwrap import dedent
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, MutableSequence, Optional

from omegaconf import DictConfig, OmegaConf, flag_override, open_dict
from omegaconf import Container, DictConfig, OmegaConf, flag_override, open_dict
from omegaconf.errors import (
ConfigAttributeError,
ConfigKeyError,
Expand Down Expand Up @@ -152,6 +152,62 @@ def load_configuration(
except OmegaConfBaseException as e:
raise ConfigCompositionException().with_traceback(sys.exc_info()[2]) from e

def _process_config_searchpath(
self,
config_name: Optional[str],
parsed_overrides: List[Override],
repo: CachingConfigRepository,
) -> None:
if config_name is not None:
loaded = repo.load_config(config_path=config_name)
primary_config: Container
if loaded is None:
primary_config = OmegaConf.create()
else:
primary_config = loaded.config
else:
primary_config = OmegaConf.create()

def is_searchpath_override(v: Override) -> bool:
return v.get_key_element() == "hydra.searchpath"

override = None
for v in parsed_overrides:
if is_searchpath_override(v):
override = v.value()
break

searchpath = OmegaConf.select(primary_config, "hydra.searchpath")
if override is not None:
provider = "hydra.searchpath in command-line"
searchpath = override
else:
provider = "hydra.searchpath in main"

def _err() -> None:
raise ConfigCompositionException(
f"hydra.searchpath must be a list of strings. Got: {searchpath}"
)

if searchpath is None:
return

# validate hydra.searchpath.
# Note that we cannot rely on OmegaConf validation here because we did not yet merge with the Hydra schema node
if not isinstance(searchpath, MutableSequence):
_err()
for v in searchpath:
if not isinstance(v, str):
_err()

new_csp = copy.deepcopy(self.config_search_path)
schema = new_csp.get_path().pop(-1)
assert schema.provider == "schema"
for sp in searchpath:
new_csp.append(provider=provider, path=sp)
new_csp.append("schema", "structured://")
repo.initialize_sources(new_csp)

def _load_configuration_impl(
self,
config_name: Optional[str],
Expand All @@ -165,6 +221,8 @@ def _load_configuration_impl(
parser = OverridesParser.create()
parsed_overrides = parser.parse_overrides(overrides=overrides)

self._process_config_searchpath(config_name, parsed_overrides, caching_repo)

self.validate_sweep_overrides_legal(
overrides=parsed_overrides, run_mode=run_mode, from_shell=from_shell
)
Expand Down Expand Up @@ -401,7 +459,17 @@ def _load_single_config(

assert isinstance(merged, DictConfig)

return self._embed_result_config(ret, default.package)
res = self._embed_result_config(ret, default.package)
if (
not default.primary
and config_path != "hydra/config"
and OmegaConf.select(res.config, "hydra.searchpath") is not None
):
raise ConfigCompositionException(
f"In '{config_path}': Overriding hydra.searchpath is only supported from the primary config"
)

return res

@staticmethod
def _embed_result_config(
Expand Down
15 changes: 14 additions & 1 deletion hydra/_internal/config_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,20 @@ def get_group_options(
def get_sources(self) -> List[ConfigSource]:
...

@abstractmethod
def initialize_sources(self, config_search_path: ConfigSearchPath) -> None:
...


class ConfigRepository(IConfigRepository):

config_search_path: ConfigSearchPath
sources: List[ConfigSource]

def __init__(self, config_search_path: ConfigSearchPath) -> None:
self.initialize_sources(config_search_path)

def initialize_sources(self, config_search_path: ConfigSearchPath) -> None:
self.sources = []
for search_path in config_search_path.get_path():
assert search_path.path is not None
Expand All @@ -71,7 +78,7 @@ def get_schema_source(self) -> ConfigSource:
assert (
source.__class__.__name__ == "StructuredConfigSource"
and source.provider == "schema"
)
), "schema config source must be last"
return source

def load_config(self, config_path: str) -> Optional[ConfigResult]:
Expand Down Expand Up @@ -313,6 +320,12 @@ def __init__(self, delegate: IConfigRepository):
def get_schema_source(self) -> ConfigSource:
return self.delegate.get_schema_source()

def initialize_sources(self, config_search_path: ConfigSearchPath) -> None:
self.delegate.initialize_sources(config_search_path)
# not clearing the cache.
# For the use case this is used, the only thing in the cache is the primary config
# and we want to keep it even though we re-initialized the sources.

def load_config(self, config_path: str) -> Optional[ConfigResult]:
cache_key = f"config_path={config_path}"
if cache_key in self.cache:
Expand Down
4 changes: 2 additions & 2 deletions hydra/_internal/config_search_path_impl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List, Optional, Sequence
from typing import List, MutableSequence, Optional

from hydra.core.config_search_path import (
ConfigSearchPath,
Expand All @@ -14,7 +14,7 @@ class ConfigSearchPathImpl(ConfigSearchPath):
def __init__(self) -> None:
self.config_search_path = []

def get_path(self) -> Sequence[SearchPathElement]:
def get_path(self) -> MutableSequence[SearchPathElement]:
return self.config_search_path

def find_last_match(self, reference: SearchPathQuery) -> int:
Expand Down
4 changes: 4 additions & 0 deletions hydra/conf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ class HydraConf:
]
)

# Elements to append to the config search path.
# Note: This can only be configured in the primary config.
searchpath: List[str] = field(default_factory=list)

# Normal run output configuration
run: RunDir = RunDir()
# Multi-run output configuration
Expand Down
4 changes: 2 additions & 2 deletions hydra/core/config_search_path.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Sequence
from typing import MutableSequence, Optional


class SearchPathElement:
Expand All @@ -28,7 +28,7 @@ class SearchPathQuery:

class ConfigSearchPath(ABC):
@abstractmethod
def get_path(self) -> Sequence[SearchPathElement]:
def get_path(self) -> MutableSequence[SearchPathElement]:
...

@abstractmethod
Expand Down
1 change: 1 addition & 0 deletions news/274.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support for configuring the config search path from the primary config
146 changes: 142 additions & 4 deletions tests/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Dict, List, Optional

from omegaconf import OmegaConf
from pytest import mark, param, raises
from pytest import fixture, mark, param, raises

from hydra._internal.config_search_path_impl import ConfigSearchPathImpl
from hydra.core.config_search_path import SearchPathQuery
Expand Down Expand Up @@ -319,8 +319,9 @@ class Config:
assert cfg == expected


@mark.usefixtures("hydra_restore_singletons")
class TestAdd:
def test_add(self, hydra_restore_singletons: Any) -> None:
def test_add(self) -> None:
ConfigStore.instance().store(name="config", node={"key": 0})
with initialize():
with raises(
Expand All @@ -332,7 +333,7 @@ def test_add(self, hydra_restore_singletons: Any) -> None:
cfg = compose(config_name="config", overrides=["key=1"])
assert cfg == {"key": 1}

def test_force_add(self, hydra_restore_singletons: Any) -> None:
def test_force_add(self) -> None:
ConfigStore.instance().store(name="config", node={"key": 0})
with initialize():
cfg = compose(config_name="config", overrides=["++key=1"])
Expand All @@ -341,7 +342,7 @@ def test_force_add(self, hydra_restore_singletons: Any) -> None:
cfg = compose(config_name="config", overrides=["++key2=1"])
assert cfg == {"key": 0, "key2": 1}

def test_add_config_group(self, hydra_restore_singletons: Any) -> None:
def test_add_config_group(self) -> None:
ConfigStore.instance().store(group="group", name="a0", node={"key": 0})
ConfigStore.instance().store(group="group", name="a1", node={"key": 1})
with initialize():
Expand All @@ -361,3 +362,140 @@ def test_add_config_group(self, hydra_restore_singletons: Any) -> None:
),
):
compose(overrides=["++group=a1"])


@mark.usefixtures("hydra_restore_singletons")
class TestConfigSearchPathOverride:
@fixture
def init_configs(self) -> Any:
cs = ConfigStore.instance()
cs.store(
name="with_sp",
node={"hydra": {"searchpath": ["pkg://hydra.test_utils.configs"]}},
)
cs.store(name="without_sp", node={})

cs.store(name="bad1", node={"hydra": {"searchpath": 42}})
cs.store(name="bad2", node={"hydra": {"searchpath": [42]}})

# Using this triggers an error. Only primary configs are allowed to override hydra.searchpath
cs.store(
group="group2",
name="overriding_sp",
node={"hydra": {"searchpath": ["abc"]}},
package="_global_",
)
yield

@fixture
def initialize_hydra(self) -> Any:
try:
init = initialize()
init.__enter__()
yield
finally:
init.__exit__(*sys.exc_info())

@mark.parametrize(
("config_name", "overrides", "expected"),
[
# config group is interpreted as simple config value addition.
param("without_sp", ["+group1=file1"], {"group1": "file1"}, id="without"),
param("with_sp", ["+group1=file1"], {"foo": 10}, id="with"),
# Overriding hydra.searchpath
param(
"without_sp",
["hydra.searchpath=[pkg://hydra.test_utils.configs]", "+group1=file1"],
{"foo": 10},
id="sp_added_by_override",
),
param(
"with_sp",
["hydra.searchpath=[]", "+group1=file1"],
{"group1": "file1"},
id="sp_removed_by_override",
),
],
)
def test_searchpath_in_primary_config(
self,
initialize_hydra: Any,
init_configs: Any,
config_name: str,
overrides: List[str],
expected: Any,
) -> None:
cfg = compose(config_name=config_name, overrides=overrides)
assert cfg == expected

@mark.parametrize(
("config_name", "overrides", "expected"),
[
param(
"bad1",
[],
raises(
ConfigCompositionException,
match=re.escape(
"hydra.searchpath must be a list of strings. Got: 42"
),
),
id="bad_cp_in_config",
),
param(
"bad2",
[],
raises(
ConfigCompositionException,
match=re.escape(
"hydra.searchpath must be a list of strings. Got: [42]"
),
),
id="bad_cp_element_in_config",
),
param(
"without_sp",
["hydra.searchpath=42"],
raises(
ConfigCompositionException,
match=re.escape(
"hydra.searchpath must be a list of strings. Got: 42"
),
),
id="bad_override1",
),
param(
"without_sp",
["hydra.searchpath=[42]"],
raises(
ConfigCompositionException,
match=re.escape(
"hydra.searchpath must be a list of strings. Got: [42]"
),
),
id="bad_override2",
),
param(
"without_sp",
["+group2=overriding_sp"],
raises(
ConfigCompositionException,
match=re.escape(
"In 'group2/overriding_sp': Overriding hydra.searchpath "
"is only supported from the primary config"
),
),
id="overriding_sp_from_non_primary_config",
),
],
)
def test_searchpath_config_errors(
self,
initialize_hydra: Any,
init_configs: Any,
config_name: str,
overrides: List[str],
expected: Any,
) -> None:
with expected:
compose(config_name=config_name, overrides=overrides)

0 comments on commit 887c38f

Please sign in to comment.