Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

274 searchpath #1450

Merged
merged 4 commits into from
Mar 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 71 additions & 3 deletions hydra/_internal/config_loader_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from collections import defaultdict
from dataclasses import dataclass
from textwrap import dedent
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, MutableSequence, Optional

from omegaconf import DictConfig, OmegaConf, flag_override, open_dict
from omegaconf import Container, DictConfig, OmegaConf, flag_override, open_dict
from omegaconf.errors import (
ConfigAttributeError,
ConfigKeyError,
Expand Down Expand Up @@ -152,6 +152,62 @@ def load_configuration(
except OmegaConfBaseException as e:
raise ConfigCompositionException().with_traceback(sys.exc_info()[2]) from e

def _process_config_searchpath(
self,
config_name: Optional[str],
parsed_overrides: List[Override],
repo: CachingConfigRepository,
) -> None:
if config_name is not None:
loaded = repo.load_config(config_path=config_name)
primary_config: Container
if loaded is None:
primary_config = OmegaConf.create()
else:
primary_config = loaded.config
else:
primary_config = OmegaConf.create()

def is_searchpath_override(v: Override) -> bool:
return v.get_key_element() == "hydra.searchpath"

override = None
for v in parsed_overrides:
if is_searchpath_override(v):
override = v.value()
break

searchpath = OmegaConf.select(primary_config, "hydra.searchpath")
if override is not None:
provider = "hydra.searchpath in command-line"
searchpath = override
else:
provider = "hydra.searchpath in main"

def _err() -> None:
raise ConfigCompositionException(
f"hydra.searchpath must be a list of strings. Got: {searchpath}"
)

if searchpath is None:
return

# validate hydra.searchpath.
# Note that we cannot rely on OmegaConf validation here because we did not yet merge with the Hydra schema node
if not isinstance(searchpath, MutableSequence):
_err()
for v in searchpath:
if not isinstance(v, str):
_err()

new_csp = copy.deepcopy(self.config_search_path)
schema = new_csp.get_path().pop(-1)
assert schema.provider == "schema"
for sp in searchpath:
new_csp.append(provider=provider, path=sp)
new_csp.append("schema", "structured://")
repo.initialize_sources(new_csp)

def _load_configuration_impl(
self,
config_name: Optional[str],
Expand All @@ -165,6 +221,8 @@ def _load_configuration_impl(
parser = OverridesParser.create()
parsed_overrides = parser.parse_overrides(overrides=overrides)

self._process_config_searchpath(config_name, parsed_overrides, caching_repo)

self.validate_sweep_overrides_legal(
overrides=parsed_overrides, run_mode=run_mode, from_shell=from_shell
)
Expand Down Expand Up @@ -401,7 +459,17 @@ def _load_single_config(

assert isinstance(merged, DictConfig)

return self._embed_result_config(ret, default.package)
res = self._embed_result_config(ret, default.package)
if (
not default.primary
and config_path != "hydra/config"
and OmegaConf.select(res.config, "hydra.searchpath") is not None
):
raise ConfigCompositionException(
f"In '{config_path}': Overriding hydra.searchpath is only supported from the primary config"
)

return res

@staticmethod
def _embed_result_config(
Expand Down
15 changes: 14 additions & 1 deletion hydra/_internal/config_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,20 @@ def get_group_options(
def get_sources(self) -> List[ConfigSource]:
...

@abstractmethod
def initialize_sources(self, config_search_path: ConfigSearchPath) -> None:
...


class ConfigRepository(IConfigRepository):

config_search_path: ConfigSearchPath
sources: List[ConfigSource]

def __init__(self, config_search_path: ConfigSearchPath) -> None:
self.initialize_sources(config_search_path)

def initialize_sources(self, config_search_path: ConfigSearchPath) -> None:
self.sources = []
for search_path in config_search_path.get_path():
assert search_path.path is not None
Expand All @@ -71,7 +78,7 @@ def get_schema_source(self) -> ConfigSource:
assert (
source.__class__.__name__ == "StructuredConfigSource"
and source.provider == "schema"
)
), "schema config source must be last"
return source

def load_config(self, config_path: str) -> Optional[ConfigResult]:
Expand Down Expand Up @@ -313,6 +320,12 @@ def __init__(self, delegate: IConfigRepository):
def get_schema_source(self) -> ConfigSource:
return self.delegate.get_schema_source()

def initialize_sources(self, config_search_path: ConfigSearchPath) -> None:
self.delegate.initialize_sources(config_search_path)
# not clearing the cache.
# For the use case this is used, the only thing in the cache is the primary config
# and we want to keep it even though we re-initialized the sources.

def load_config(self, config_path: str) -> Optional[ConfigResult]:
cache_key = f"config_path={config_path}"
if cache_key in self.cache:
Expand Down
4 changes: 2 additions & 2 deletions hydra/_internal/config_search_path_impl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List, Optional, Sequence
from typing import List, MutableSequence, Optional

from hydra.core.config_search_path import (
ConfigSearchPath,
Expand All @@ -14,7 +14,7 @@ class ConfigSearchPathImpl(ConfigSearchPath):
def __init__(self) -> None:
self.config_search_path = []

def get_path(self) -> Sequence[SearchPathElement]:
def get_path(self) -> MutableSequence[SearchPathElement]:
return self.config_search_path

def find_last_match(self, reference: SearchPathQuery) -> int:
Expand Down
4 changes: 4 additions & 0 deletions hydra/conf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ class HydraConf:
]
)

# Elements to append to the config search path.
# Note: This can only be configured in the primary config.
searchpath: List[str] = field(default_factory=list)

# Normal run output configuration
run: RunDir = RunDir()
# Multi-run output configuration
Expand Down
4 changes: 2 additions & 2 deletions hydra/core/config_search_path.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Sequence
from typing import MutableSequence, Optional


class SearchPathElement:
Expand All @@ -28,7 +28,7 @@ class SearchPathQuery:

class ConfigSearchPath(ABC):
@abstractmethod
def get_path(self) -> Sequence[SearchPathElement]:
def get_path(self) -> MutableSequence[SearchPathElement]:
...

@abstractmethod
Expand Down
1 change: 1 addition & 0 deletions news/274.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support for configuring the config search path from the primary config
Loading