From 887c38f36c3b56645d91116a5897240091620274 Mon Sep 17 00:00:00 2001 From: Omry Yadan Date: Wed, 3 Mar 2021 17:40:39 -0800 Subject: [PATCH] Support for configuring the config searchpath This is only allowed from the primary config. --- hydra/_internal/config_loader_impl.py | 74 ++++++++++- hydra/_internal/config_repository.py | 15 ++- hydra/_internal/config_search_path_impl.py | 4 +- hydra/conf/__init__.py | 4 + hydra/core/config_search_path.py | 4 +- news/274.feature | 1 + tests/test_compose.py | 146 ++++++++++++++++++++- 7 files changed, 236 insertions(+), 12 deletions(-) create mode 100644 news/274.feature diff --git a/hydra/_internal/config_loader_impl.py b/hydra/_internal/config_loader_impl.py index 75d70f27f3a..a596d83a199 100644 --- a/hydra/_internal/config_loader_impl.py +++ b/hydra/_internal/config_loader_impl.py @@ -8,9 +8,9 @@ from collections import defaultdict from dataclasses import dataclass from textwrap import dedent -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, MutableSequence, Optional -from omegaconf import DictConfig, OmegaConf, flag_override, open_dict +from omegaconf import Container, DictConfig, OmegaConf, flag_override, open_dict from omegaconf.errors import ( ConfigAttributeError, ConfigKeyError, @@ -152,6 +152,62 @@ def load_configuration( except OmegaConfBaseException as e: raise ConfigCompositionException().with_traceback(sys.exc_info()[2]) from e + def _process_config_searchpath( + self, + config_name: Optional[str], + parsed_overrides: List[Override], + repo: CachingConfigRepository, + ) -> None: + if config_name is not None: + loaded = repo.load_config(config_path=config_name) + primary_config: Container + if loaded is None: + primary_config = OmegaConf.create() + else: + primary_config = loaded.config + else: + primary_config = OmegaConf.create() + + def is_searchpath_override(v: Override) -> bool: + return v.get_key_element() == "hydra.searchpath" + + override = None + for v in parsed_overrides: + if is_searchpath_override(v): + override = v.value() + break + + searchpath = OmegaConf.select(primary_config, "hydra.searchpath") + if override is not None: + provider = "hydra.searchpath in command-line" + searchpath = override + else: + provider = "hydra.searchpath in main" + + def _err() -> None: + raise ConfigCompositionException( + f"hydra.searchpath must be a list of strings. Got: {searchpath}" + ) + + if searchpath is None: + return + + # validate hydra.searchpath. + # Note that we cannot rely on OmegaConf validation here because we did not yet merge with the Hydra schema node + if not isinstance(searchpath, MutableSequence): + _err() + for v in searchpath: + if not isinstance(v, str): + _err() + + new_csp = copy.deepcopy(self.config_search_path) + schema = new_csp.get_path().pop(-1) + assert schema.provider == "schema" + for sp in searchpath: + new_csp.append(provider=provider, path=sp) + new_csp.append("schema", "structured://") + repo.initialize_sources(new_csp) + def _load_configuration_impl( self, config_name: Optional[str], @@ -165,6 +221,8 @@ def _load_configuration_impl( parser = OverridesParser.create() parsed_overrides = parser.parse_overrides(overrides=overrides) + self._process_config_searchpath(config_name, parsed_overrides, caching_repo) + self.validate_sweep_overrides_legal( overrides=parsed_overrides, run_mode=run_mode, from_shell=from_shell ) @@ -401,7 +459,17 @@ def _load_single_config( assert isinstance(merged, DictConfig) - return self._embed_result_config(ret, default.package) + res = self._embed_result_config(ret, default.package) + if ( + not default.primary + and config_path != "hydra/config" + and OmegaConf.select(res.config, "hydra.searchpath") is not None + ): + raise ConfigCompositionException( + f"In '{config_path}': Overriding hydra.searchpath is only supported from the primary config" + ) + + return res @staticmethod def _embed_result_config( diff --git a/hydra/_internal/config_repository.py b/hydra/_internal/config_repository.py index 40f594cbad5..047c20a996f 100644 --- a/hydra/_internal/config_repository.py +++ b/hydra/_internal/config_repository.py @@ -50,6 +50,10 @@ def get_group_options( def get_sources(self) -> List[ConfigSource]: ... + @abstractmethod + def initialize_sources(self, config_search_path: ConfigSearchPath) -> None: + ... + class ConfigRepository(IConfigRepository): @@ -57,6 +61,9 @@ class ConfigRepository(IConfigRepository): sources: List[ConfigSource] def __init__(self, config_search_path: ConfigSearchPath) -> None: + self.initialize_sources(config_search_path) + + def initialize_sources(self, config_search_path: ConfigSearchPath) -> None: self.sources = [] for search_path in config_search_path.get_path(): assert search_path.path is not None @@ -71,7 +78,7 @@ def get_schema_source(self) -> ConfigSource: assert ( source.__class__.__name__ == "StructuredConfigSource" and source.provider == "schema" - ) + ), "schema config source must be last" return source def load_config(self, config_path: str) -> Optional[ConfigResult]: @@ -313,6 +320,12 @@ def __init__(self, delegate: IConfigRepository): def get_schema_source(self) -> ConfigSource: return self.delegate.get_schema_source() + def initialize_sources(self, config_search_path: ConfigSearchPath) -> None: + self.delegate.initialize_sources(config_search_path) + # not clearing the cache. + # For the use case this is used, the only thing in the cache is the primary config + # and we want to keep it even though we re-initialized the sources. + def load_config(self, config_path: str) -> Optional[ConfigResult]: cache_key = f"config_path={config_path}" if cache_key in self.cache: diff --git a/hydra/_internal/config_search_path_impl.py b/hydra/_internal/config_search_path_impl.py index e4130c7d477..ab186e07b13 100644 --- a/hydra/_internal/config_search_path_impl.py +++ b/hydra/_internal/config_search_path_impl.py @@ -1,5 +1,5 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -from typing import List, Optional, Sequence +from typing import List, MutableSequence, Optional from hydra.core.config_search_path import ( ConfigSearchPath, @@ -14,7 +14,7 @@ class ConfigSearchPathImpl(ConfigSearchPath): def __init__(self) -> None: self.config_search_path = [] - def get_path(self) -> Sequence[SearchPathElement]: + def get_path(self) -> MutableSequence[SearchPathElement]: return self.config_search_path def find_last_match(self, reference: SearchPathQuery) -> int: diff --git a/hydra/conf/__init__.py b/hydra/conf/__init__.py index d3569afc588..2e34280717b 100644 --- a/hydra/conf/__init__.py +++ b/hydra/conf/__init__.py @@ -101,6 +101,10 @@ class HydraConf: ] ) + # Elements to append to the config search path. + # Note: This can only be configured in the primary config. + searchpath: List[str] = field(default_factory=list) + # Normal run output configuration run: RunDir = RunDir() # Multi-run output configuration diff --git a/hydra/core/config_search_path.py b/hydra/core/config_search_path.py index f329528b172..b2e9772fbe2 100644 --- a/hydra/core/config_search_path.py +++ b/hydra/core/config_search_path.py @@ -1,7 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional, Sequence +from typing import MutableSequence, Optional class SearchPathElement: @@ -28,7 +28,7 @@ class SearchPathQuery: class ConfigSearchPath(ABC): @abstractmethod - def get_path(self) -> Sequence[SearchPathElement]: + def get_path(self) -> MutableSequence[SearchPathElement]: ... @abstractmethod diff --git a/news/274.feature b/news/274.feature new file mode 100644 index 00000000000..8a1c0e4c572 --- /dev/null +++ b/news/274.feature @@ -0,0 +1 @@ +Support for configuring the config search path from the primary config diff --git a/tests/test_compose.py b/tests/test_compose.py index 59ad65c330a..d05368b5033 100644 --- a/tests/test_compose.py +++ b/tests/test_compose.py @@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional from omegaconf import OmegaConf -from pytest import mark, param, raises +from pytest import fixture, mark, param, raises from hydra._internal.config_search_path_impl import ConfigSearchPathImpl from hydra.core.config_search_path import SearchPathQuery @@ -319,8 +319,9 @@ class Config: assert cfg == expected +@mark.usefixtures("hydra_restore_singletons") class TestAdd: - def test_add(self, hydra_restore_singletons: Any) -> None: + def test_add(self) -> None: ConfigStore.instance().store(name="config", node={"key": 0}) with initialize(): with raises( @@ -332,7 +333,7 @@ def test_add(self, hydra_restore_singletons: Any) -> None: cfg = compose(config_name="config", overrides=["key=1"]) assert cfg == {"key": 1} - def test_force_add(self, hydra_restore_singletons: Any) -> None: + def test_force_add(self) -> None: ConfigStore.instance().store(name="config", node={"key": 0}) with initialize(): cfg = compose(config_name="config", overrides=["++key=1"]) @@ -341,7 +342,7 @@ def test_force_add(self, hydra_restore_singletons: Any) -> None: cfg = compose(config_name="config", overrides=["++key2=1"]) assert cfg == {"key": 0, "key2": 1} - def test_add_config_group(self, hydra_restore_singletons: Any) -> None: + def test_add_config_group(self) -> None: ConfigStore.instance().store(group="group", name="a0", node={"key": 0}) ConfigStore.instance().store(group="group", name="a1", node={"key": 1}) with initialize(): @@ -361,3 +362,140 @@ def test_add_config_group(self, hydra_restore_singletons: Any) -> None: ), ): compose(overrides=["++group=a1"]) + + +@mark.usefixtures("hydra_restore_singletons") +class TestConfigSearchPathOverride: + @fixture + def init_configs(self) -> Any: + cs = ConfigStore.instance() + cs.store( + name="with_sp", + node={"hydra": {"searchpath": ["pkg://hydra.test_utils.configs"]}}, + ) + cs.store(name="without_sp", node={}) + + cs.store(name="bad1", node={"hydra": {"searchpath": 42}}) + cs.store(name="bad2", node={"hydra": {"searchpath": [42]}}) + + # Using this triggers an error. Only primary configs are allowed to override hydra.searchpath + cs.store( + group="group2", + name="overriding_sp", + node={"hydra": {"searchpath": ["abc"]}}, + package="_global_", + ) + yield + + @fixture + def initialize_hydra(self) -> Any: + try: + init = initialize() + init.__enter__() + yield + finally: + init.__exit__(*sys.exc_info()) + + @mark.parametrize( + ("config_name", "overrides", "expected"), + [ + # config group is interpreted as simple config value addition. + param("without_sp", ["+group1=file1"], {"group1": "file1"}, id="without"), + param("with_sp", ["+group1=file1"], {"foo": 10}, id="with"), + # Overriding hydra.searchpath + param( + "without_sp", + ["hydra.searchpath=[pkg://hydra.test_utils.configs]", "+group1=file1"], + {"foo": 10}, + id="sp_added_by_override", + ), + param( + "with_sp", + ["hydra.searchpath=[]", "+group1=file1"], + {"group1": "file1"}, + id="sp_removed_by_override", + ), + ], + ) + def test_searchpath_in_primary_config( + self, + initialize_hydra: Any, + init_configs: Any, + config_name: str, + overrides: List[str], + expected: Any, + ) -> None: + cfg = compose(config_name=config_name, overrides=overrides) + assert cfg == expected + + @mark.parametrize( + ("config_name", "overrides", "expected"), + [ + param( + "bad1", + [], + raises( + ConfigCompositionException, + match=re.escape( + "hydra.searchpath must be a list of strings. Got: 42" + ), + ), + id="bad_cp_in_config", + ), + param( + "bad2", + [], + raises( + ConfigCompositionException, + match=re.escape( + "hydra.searchpath must be a list of strings. Got: [42]" + ), + ), + id="bad_cp_element_in_config", + ), + param( + "without_sp", + ["hydra.searchpath=42"], + raises( + ConfigCompositionException, + match=re.escape( + "hydra.searchpath must be a list of strings. Got: 42" + ), + ), + id="bad_override1", + ), + param( + "without_sp", + ["hydra.searchpath=[42]"], + raises( + ConfigCompositionException, + match=re.escape( + "hydra.searchpath must be a list of strings. Got: [42]" + ), + ), + id="bad_override2", + ), + param( + "without_sp", + ["+group2=overriding_sp"], + raises( + ConfigCompositionException, + match=re.escape( + "In 'group2/overriding_sp': Overriding hydra.searchpath " + "is only supported from the primary config" + ), + ), + id="overriding_sp_from_non_primary_config", + ), + ], + ) + def test_searchpath_config_errors( + self, + initialize_hydra: Any, + init_configs: Any, + config_name: str, + overrides: List[str], + expected: Any, + ) -> None: + with expected: + compose(config_name=config_name, overrides=overrides)