diff --git a/examples/plugins/example_launcher_plugin/hydra_plugins/example_launcher_plugin/example_launcher.py b/examples/plugins/example_launcher_plugin/hydra_plugins/example_launcher_plugin/example_launcher.py index a750a27d552..b15f0d6b6d8 100644 --- a/examples/plugins/example_launcher_plugin/hydra_plugins/example_launcher_plugin/example_launcher.py +++ b/examples/plugins/example_launcher_plugin/hydra_plugins/example_launcher_plugin/example_launcher.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import Optional, Sequence -from hydra.core.config_loader import ConfigLoader +from hydra.types import HydraContext from hydra.core.config_store import ConfigStore from hydra.core.singleton import Singleton from hydra.core.utils import ( @@ -48,8 +48,8 @@ class LauncherConfig: class ExampleLauncher(Launcher): def __init__(self, foo: str, bar: str) -> None: self.config: Optional[DictConfig] = None - self.config_loader: Optional[ConfigLoader] = None self.task_function: Optional[TaskFunction] = None + self.hydra_context: Optional[HydraContext] = None # foo and var are coming from the the plugin's configuration self.foo = foo @@ -57,12 +57,13 @@ def __init__(self, foo: str, bar: str) -> None: def setup( self, - config: DictConfig, - config_loader: ConfigLoader, + *, + hydra_context: HydraContext, task_function: TaskFunction, + config: DictConfig, ) -> None: self.config = config - self.config_loader = config_loader + self.hydra_context = hydra_context self.task_function = task_function def launch( @@ -75,7 +76,7 @@ def launch( """ setup_globals() assert self.config is not None - assert self.config_loader is not None + assert self.hydra_context is not None assert self.task_function is not None configure_log(self.config.hydra.hydra_logging, self.config.hydra.verbose) @@ -91,7 +92,7 @@ def launch( idx = initial_job_idx + idx lst = " ".join(filter_overrides(overrides)) log.info(f"\t#{idx} : {lst}") - sweep_config = self.config_loader.load_sweep_config( + sweep_config = self.hydra_context.config_loader.load_sweep_config( self.config, list(overrides) ) with open_dict(sweep_config): @@ -113,8 +114,9 @@ def launch( Singleton.set_state(state) ret = run_job( - config=sweep_config, + hydra_context=self.hydra_context, task_function=self.task_function, + config=sweep_config, job_dir_key="hydra.sweep.dir", job_subdir_key="hydra.sweep.subdir", ) diff --git a/examples/plugins/example_sweeper_plugin/hydra_plugins/example_sweeper_plugin/example_sweeper.py b/examples/plugins/example_sweeper_plugin/hydra_plugins/example_sweeper_plugin/example_sweeper.py index 05254ff75a0..0d15cabb518 100644 --- a/examples/plugins/example_sweeper_plugin/hydra_plugins/example_sweeper_plugin/example_sweeper.py +++ b/examples/plugins/example_sweeper_plugin/hydra_plugins/example_sweeper_plugin/example_sweeper.py @@ -6,7 +6,7 @@ from pathlib import Path from typing import Any, Iterable, List, Optional, Sequence -from hydra.core.config_loader import ConfigLoader +from hydra.types import HydraContext from hydra.core.config_store import ConfigStore from hydra.core.override_parser.overrides_parser import OverridesParser from hydra.core.plugins import Plugins @@ -45,21 +45,23 @@ def __init__(self, max_batch_size: int, foo: str, bar: str): self.max_batch_size = max_batch_size self.config: Optional[DictConfig] = None self.launcher: Optional[Launcher] = None + self.hydra_context: Optional[HydraContext] = None self.job_results = None self.foo = foo self.bar = bar def setup( self, - config: DictConfig, - config_loader: ConfigLoader, + *, + hydra_context: HydraContext, task_function: TaskFunction, + config: DictConfig, ) -> None: self.config = config - self.config_loader = config_loader self.launcher = Plugins.instance().instantiate_launcher( - config=config, config_loader=config_loader, task_function=task_function + hydra_context=hydra_context, task_function=task_function, config=config ) + self.hydra_context = hydra_context def sweep(self, arguments: List[str]) -> Any: assert self.config is not None diff --git a/hydra/_internal/core_plugins/basic_launcher.py b/hydra/_internal/core_plugins/basic_launcher.py index fbad1a0fc13..af1c3493282 100644 --- a/hydra/_internal/core_plugins/basic_launcher.py +++ b/hydra/_internal/core_plugins/basic_launcher.py @@ -6,7 +6,6 @@ from omegaconf import DictConfig, open_dict -from hydra.core.config_loader import ConfigLoader from hydra.core.config_store import ConfigStore from hydra.core.utils import ( JobReturn, @@ -16,7 +15,7 @@ setup_globals, ) from hydra.plugins.launcher import Launcher -from hydra.types import TaskFunction +from hydra.types import HydraContext, TaskFunction log = logging.getLogger(__name__) @@ -35,26 +34,27 @@ class BasicLauncher(Launcher): def __init__(self) -> None: super().__init__() self.config: Optional[DictConfig] = None - self.config_loader: Optional[ConfigLoader] = None self.task_function: Optional[TaskFunction] = None + self.hydra_context: Optional[HydraContext] = None def setup( self, - config: DictConfig, - config_loader: ConfigLoader, + *, + hydra_context: HydraContext, task_function: TaskFunction, + config: DictConfig, ) -> None: self.config = config - self.config_loader = config_loader + self.hydra_context = hydra_context self.task_function = task_function def launch( self, job_overrides: Sequence[Sequence[str]], initial_job_idx: int ) -> Sequence[JobReturn]: setup_globals() + assert self.hydra_context is not None assert self.config is not None assert self.task_function is not None - assert self.config_loader is not None configure_log(self.config.hydra.hydra_logging, self.config.hydra.verbose) sweep_dir = self.config.hydra.sweep.dir @@ -65,15 +65,16 @@ def launch( idx = initial_job_idx + idx lst = " ".join(filter_overrides(overrides)) log.info(f"\t#{idx} : {lst}") - sweep_config = self.config_loader.load_sweep_config( + sweep_config = self.hydra_context.config_loader.load_sweep_config( self.config, list(overrides) ) with open_dict(sweep_config): sweep_config.hydra.job.id = idx sweep_config.hydra.job.num = idx ret = run_job( - config=sweep_config, + hydra_context=self.hydra_context, task_function=self.task_function, + config=sweep_config, job_dir_key="hydra.sweep.dir", job_subdir_key="hydra.sweep.subdir", ) diff --git a/hydra/_internal/core_plugins/basic_sweeper.py b/hydra/_internal/core_plugins/basic_sweeper.py index 57583933b65..f1b41264413 100644 --- a/hydra/_internal/core_plugins/basic_sweeper.py +++ b/hydra/_internal/core_plugins/basic_sweeper.py @@ -25,7 +25,6 @@ from omegaconf import DictConfig, OmegaConf -from hydra.core.config_loader import ConfigLoader from hydra.core.config_store import ConfigStore from hydra.core.override_parser.overrides_parser import OverridesParser from hydra.core.override_parser.types import Override @@ -33,7 +32,7 @@ from hydra.errors import HydraException from hydra.plugins.launcher import Launcher from hydra.plugins.sweeper import Sweeper -from hydra.types import TaskFunction +from hydra.types import HydraContext, TaskFunction @dataclass @@ -64,23 +63,26 @@ def __init__(self, max_batch_size: Optional[int]) -> None: self.batch_index = 0 self.max_batch_size = max_batch_size - self.config_loader: Optional[ConfigLoader] = None + self.hydra_context: Optional[HydraContext] = None self.config: Optional[DictConfig] = None self.launcher: Optional[Launcher] = None def setup( self, - config: DictConfig, - config_loader: ConfigLoader, + *, + hydra_context: HydraContext, task_function: TaskFunction, + config: DictConfig, ) -> None: from hydra.core.plugins import Plugins - self.config_loader = config_loader + self.hydra_context = hydra_context self.config = config self.launcher = Plugins.instance().instantiate_launcher( - config=config, config_loader=config_loader, task_function=task_function + hydra_context=hydra_context, + task_function=task_function, + config=config, ) @staticmethod @@ -128,8 +130,9 @@ def split_arguments( def sweep(self, arguments: List[str]) -> Any: assert self.config is not None assert self.launcher is not None + assert self.hydra_context is not None - parser = OverridesParser.create(config_loader=self.config_loader) + parser = OverridesParser.create(config_loader=self.hydra_context.config_loader) overrides = parser.parse_overrides(arguments) self.overrides = self.split_arguments(overrides, self.max_batch_size) diff --git a/hydra/_internal/hydra.py b/hydra/_internal/hydra.py index 1001647a6e6..56876af708b 100644 --- a/hydra/_internal/hydra.py +++ b/hydra/_internal/hydra.py @@ -26,7 +26,7 @@ from hydra.plugins.launcher import Launcher from hydra.plugins.search_path_plugin import SearchPathPlugin from hydra.plugins.sweeper import Sweeper -from hydra.types import RunMode, TaskFunction +from hydra.types import HydraContext, RunMode, TaskFunction from ..core.default_element import DefaultsTreeNode, InputDefault from .callbacks import Callbacks @@ -95,8 +95,11 @@ def run( callbacks.on_run_start(config=cfg, config_name=config_name) ret = run_job( - config=cfg, + hydra_context=HydraContext( + config_loader=self.config_loader, callbacks=callbacks + ), task_function=task_function, + config=cfg, job_dir_key="hydra.run.dir", job_subdir_key=None, configure_logging=with_log_configuration, @@ -125,7 +128,11 @@ def multirun( callbacks.on_multirun_start(config=cfg, config_name=config_name) sweeper = Plugins.instance().instantiate_sweeper( - config=cfg, config_loader=self.config_loader, task_function=task_function + config=cfg, + hydra_context=HydraContext( + config_loader=self.config_loader, callbacks=callbacks + ), + task_function=task_function, ) task_overrides = OmegaConf.to_container(cfg.hydra.overrides.task, resolve=False) assert isinstance(task_overrides, list) diff --git a/hydra/core/plugins.py b/hydra/core/plugins.py index 4f483018415..e0cf4165c78 100644 --- a/hydra/core/plugins.py +++ b/hydra/core/plugins.py @@ -12,7 +12,6 @@ from omegaconf import DictConfig from hydra._internal.sources_registry import SourcesRegistry -from hydra.core.config_loader import ConfigLoader from hydra.core.singleton import Singleton from hydra.plugins.completion_plugin import CompletionPlugin from hydra.plugins.config_source import ConfigSource @@ -20,7 +19,7 @@ from hydra.plugins.plugin import Plugin from hydra.plugins.search_path_plugin import SearchPathPlugin from hydra.plugins.sweeper import Sweeper -from hydra.types import TaskFunction +from hydra.types import HydraContext, TaskFunction from hydra.utils import instantiate @@ -106,9 +105,10 @@ def is_in_toplevel_plugins_module(clazz: str) -> bool: def instantiate_sweeper( self, - config: DictConfig, - config_loader: ConfigLoader, + *, + hydra_context: HydraContext, task_function: TaskFunction, + config: DictConfig, ) -> Sweeper: Plugins.check_usage(self) if config.hydra.sweeper is None: @@ -116,15 +116,16 @@ def instantiate_sweeper( sweeper = self._instantiate(config.hydra.sweeper) assert isinstance(sweeper, Sweeper) sweeper.setup( - config=config, config_loader=config_loader, task_function=task_function + config=config, hydra_context=hydra_context, task_function=task_function ) return sweeper def instantiate_launcher( self, - config: DictConfig, - config_loader: ConfigLoader, + *, + hydra_context: HydraContext, task_function: TaskFunction, + config: DictConfig, ) -> Launcher: Plugins.check_usage(self) if config.hydra.launcher is None: @@ -132,7 +133,7 @@ def instantiate_launcher( launcher = self._instantiate(config.hydra.launcher) assert isinstance(launcher, Launcher) launcher.setup( - config=config, config_loader=config_loader, task_function=task_function + config=config, hydra_context=hydra_context, task_function=task_function ) return launcher diff --git a/hydra/core/utils.py b/hydra/core/utils.py index 79d95b436cb..a4ec446ebe3 100644 --- a/hydra/core/utils.py +++ b/hydra/core/utils.py @@ -18,7 +18,7 @@ from hydra.core.hydra_config import HydraConfig from hydra.core.singleton import Singleton from hydra.errors import HydraJobException -from hydra.types import TaskFunction +from hydra.types import HydraContext, TaskFunction log = logging.getLogger(__name__) @@ -83,15 +83,16 @@ def filter_overrides(overrides: Sequence[str]) -> Sequence[str]: def run_job( - config: DictConfig, + *, + hydra_context: HydraContext, task_function: TaskFunction, + config: DictConfig, job_dir_key: str, job_subdir_key: Optional[str], configure_logging: bool = True, ) -> "JobReturn": - from hydra._internal.callbacks import Callbacks - callbacks = Callbacks(config) + callbacks = hydra_context.callbacks old_cwd = os.getcwd() orig_hydra_cfg = HydraConfig.instance().cfg diff --git a/hydra/plugins/launcher.py b/hydra/plugins/launcher.py index 91c25e4c930..b11ad4d32b8 100644 --- a/hydra/plugins/launcher.py +++ b/hydra/plugins/launcher.py @@ -7,9 +7,9 @@ from omegaconf import DictConfig -from hydra.core.config_loader import ConfigLoader from hydra.core.utils import JobReturn -from hydra.types import TaskFunction + +from hydra.types import TaskFunction, HydraContext from .plugin import Plugin @@ -18,9 +18,10 @@ class Launcher(Plugin): @abstractmethod def setup( self, - config: DictConfig, - config_loader: ConfigLoader, + *, + hydra_context: HydraContext, task_function: TaskFunction, + config: DictConfig, ) -> None: """ Sets this launcher instance up. diff --git a/hydra/plugins/sweeper.py b/hydra/plugins/sweeper.py index bbc27cdc599..62409ff6ff7 100644 --- a/hydra/plugins/sweeper.py +++ b/hydra/plugins/sweeper.py @@ -5,12 +5,12 @@ from abc import abstractmethod from typing import Any, List, Sequence, Optional -from hydra.core.config_loader import ConfigLoader from hydra.types import TaskFunction from omegaconf import DictConfig from .launcher import Launcher from .plugin import Plugin +from hydra.types import HydraContext class Sweeper(Plugin): @@ -20,16 +20,17 @@ class Sweeper(Plugin): (where each job typically takes a different command line arguments) """ - config_loader: Optional[ConfigLoader] + hydra_context: Optional[HydraContext] config: Optional[DictConfig] launcher: Optional[Launcher] @abstractmethod def setup( self, - config: DictConfig, - config_loader: ConfigLoader, + *, + hydra_context: HydraContext, task_function: TaskFunction, + config: DictConfig, ) -> None: raise NotImplementedError() @@ -50,9 +51,9 @@ def validate_batch_is_legal(self, batch: Sequence[Sequence[str]]) -> None: This repeat work the launcher will do, but as the launcher may be performing this in a different process/machine it's important to do it here as well to detect failures early. """ - assert self.config_loader is not None + assert self.hydra_context is not None assert self.config is not None for overrides in batch: - self.config_loader.load_sweep_config( + self.hydra_context.config_loader.load_sweep_config( master_config=self.config, sweep_overrides=list(overrides) ) diff --git a/hydra/types.py b/hydra/types.py index b082aba9728..539f0e59591 100644 --- a/hydra/types.py +++ b/hydra/types.py @@ -2,13 +2,24 @@ import warnings from dataclasses import dataclass from enum import Enum -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable from omegaconf import MISSING TaskFunction = Callable[[Any], Any] +if TYPE_CHECKING: + from hydra._internal.callbacks import Callbacks + from hydra.core.config_loader import ConfigLoader + + +@dataclass +class HydraContext: + config_loader: "ConfigLoader" + callbacks: "Callbacks" + + @dataclass class TargetConf: """ diff --git a/news/1498.api_change b/news/1498.api_change new file mode 100644 index 00000000000..39395e901e6 --- /dev/null +++ b/news/1498.api_change @@ -0,0 +1 @@ +HydraContext required in run_job, Launcher and Sweeper's setup methods, see issue for details diff --git a/plugins/hydra_ax_sweeper/hydra_plugins/hydra_ax_sweeper/_core.py b/plugins/hydra_ax_sweeper/hydra_plugins/hydra_ax_sweeper/_core.py index 9fe0431f3bb..e442da73411 100644 --- a/plugins/hydra_ax_sweeper/hydra_plugins/hydra_ax_sweeper/_core.py +++ b/plugins/hydra_ax_sweeper/hydra_plugins/hydra_ax_sweeper/_core.py @@ -6,13 +6,12 @@ from ax.core import types as ax_types # type: ignore from ax.exceptions.core import SearchSpaceExhausted # type: ignore from ax.service.ax_client import AxClient # type: ignore -from hydra.core.config_loader import ConfigLoader from hydra.core.override_parser.overrides_parser import OverridesParser from hydra.core.override_parser.types import IntervalSweep, Override, Transformer from hydra.core.plugins import Plugins from hydra.plugins.launcher import Launcher from hydra.plugins.sweeper import Sweeper -from hydra.types import TaskFunction +from hydra.types import HydraContext, TaskFunction from omegaconf import DictConfig, OmegaConf from ._earlystopper import EarlyStopper @@ -107,9 +106,9 @@ class CoreAxSweeper(Sweeper): """Class to interface with the Ax Platform""" def __init__(self, ax_config: AxConfig, max_batch_size: Optional[int]): - self.config_loader: Optional[ConfigLoader] = None self.config: Optional[DictConfig] = None self.launcher: Optional[Launcher] = None + self.hydra_context: Optional[HydraContext] = None self.job_results = None self.experiment: ExperimentConfig = ax_config.experiment @@ -130,14 +129,15 @@ def __init__(self, ax_config: AxConfig, max_batch_size: Optional[int]): def setup( self, - config: DictConfig, - config_loader: ConfigLoader, + *, + hydra_context: HydraContext, task_function: TaskFunction, + config: DictConfig, ) -> None: self.config = config - self.config_loader = config_loader + self.hydra_context = hydra_context self.launcher = Plugins.instance().instantiate_launcher( - config=config, config_loader=config_loader, task_function=task_function + config=config, hydra_context=hydra_context, task_function=task_function ) self.sweep_dir = config.hydra.sweep.dir diff --git a/plugins/hydra_ax_sweeper/hydra_plugins/hydra_ax_sweeper/ax_sweeper.py b/plugins/hydra_ax_sweeper/hydra_plugins/hydra_ax_sweeper/ax_sweeper.py index 173a342bfba..8a43260908b 100644 --- a/plugins/hydra_ax_sweeper/hydra_plugins/hydra_ax_sweeper/ax_sweeper.py +++ b/plugins/hydra_ax_sweeper/hydra_plugins/hydra_ax_sweeper/ax_sweeper.py @@ -1,9 +1,8 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved from typing import List, Optional -from hydra.core.config_loader import ConfigLoader from hydra.plugins.sweeper import Sweeper -from hydra.types import TaskFunction +from hydra.types import HydraContext, TaskFunction from omegaconf import DictConfig from .config import AxConfig @@ -19,11 +18,14 @@ def __init__(self, ax_config: AxConfig, max_batch_size: Optional[int]): def setup( self, - config: DictConfig, - config_loader: ConfigLoader, + *, + hydra_context: HydraContext, task_function: TaskFunction, + config: DictConfig, ) -> None: - return self.sweeper.setup(config, config_loader, task_function) + return self.sweeper.setup( + hydra_context=hydra_context, task_function=task_function, config=config + ) def sweep(self, arguments: List[str]) -> None: return self.sweeper.sweep(arguments) diff --git a/plugins/hydra_joblib_launcher/hydra_plugins/hydra_joblib_launcher/_core.py b/plugins/hydra_joblib_launcher/hydra_plugins/hydra_joblib_launcher/_core.py index b2b2fdae813..917c849feff 100644 --- a/plugins/hydra_joblib_launcher/hydra_plugins/hydra_joblib_launcher/_core.py +++ b/plugins/hydra_joblib_launcher/hydra_plugins/hydra_joblib_launcher/_core.py @@ -3,7 +3,6 @@ from pathlib import Path from typing import Any, Dict, List, Sequence -from hydra.core.config_loader import ConfigLoader from hydra.core.hydra_config import HydraConfig from hydra.core.singleton import Singleton from hydra.core.utils import ( @@ -13,7 +12,7 @@ run_job, setup_globals, ) -from hydra.types import TaskFunction +from hydra.types import HydraContext, TaskFunction from joblib import Parallel, delayed # type: ignore from omegaconf import DictConfig, open_dict @@ -25,7 +24,7 @@ def execute_job( idx: int, overrides: Sequence[str], - config_loader: ConfigLoader, + hydra_context: HydraContext, config: DictConfig, task_function: TaskFunction, singleton_state: Dict[Any, Any], @@ -34,13 +33,16 @@ def execute_job( setup_globals() Singleton.set_state(singleton_state) - sweep_config = config_loader.load_sweep_config(config, list(overrides)) + sweep_config = hydra_context.config_loader.load_sweep_config( + config, list(overrides) + ) with open_dict(sweep_config): sweep_config.hydra.job.id = "{}_{}".format(sweep_config.hydra.job.name, idx) sweep_config.hydra.job.num = idx HydraConfig.instance().set_config(sweep_config) ret = run_job( + hydra_context=hydra_context, config=sweep_config, task_function=task_function, job_dir_key="hydra.sweep.dir", @@ -73,8 +75,8 @@ def launch( """ setup_globals() assert launcher.config is not None - assert launcher.config_loader is not None assert launcher.task_function is not None + assert launcher.hydra_context is not None configure_log(launcher.config.hydra.hydra_logging, launcher.config.hydra.verbose) sweep_dir = Path(str(launcher.config.hydra.sweep.dir)) @@ -102,7 +104,7 @@ def launch( delayed(execute_job)( initial_job_idx + idx, overrides, - launcher.config_loader, + launcher.hydra_context, launcher.config, launcher.task_function, singleton_state, diff --git a/plugins/hydra_joblib_launcher/hydra_plugins/hydra_joblib_launcher/joblib_launcher.py b/plugins/hydra_joblib_launcher/hydra_plugins/hydra_joblib_launcher/joblib_launcher.py index 38f857d3503..66cb99d741e 100644 --- a/plugins/hydra_joblib_launcher/hydra_plugins/hydra_joblib_launcher/joblib_launcher.py +++ b/plugins/hydra_joblib_launcher/hydra_plugins/hydra_joblib_launcher/joblib_launcher.py @@ -2,10 +2,9 @@ import logging from typing import Any, Optional, Sequence -from hydra.core.config_loader import ConfigLoader from hydra.core.utils import JobReturn from hydra.plugins.launcher import Launcher -from hydra.types import TaskFunction +from hydra.types import HydraContext, TaskFunction from omegaconf import DictConfig log = logging.getLogger(__name__) @@ -22,20 +21,21 @@ def __init__(self, **kwargs: Any) -> None: https://github.com/facebookresearch/hydra/issues/357 """ self.config: Optional[DictConfig] = None - self.config_loader: Optional[ConfigLoader] = None self.task_function: Optional[TaskFunction] = None + self.hydra_context: Optional[HydraContext] = None self.joblib = kwargs def setup( self, - config: DictConfig, - config_loader: ConfigLoader, + *, + hydra_context: HydraContext, task_function: TaskFunction, + config: DictConfig, ) -> None: self.config = config - self.config_loader = config_loader self.task_function = task_function + self.hydra_context = hydra_context def launch( self, job_overrides: Sequence[Sequence[str]], initial_job_idx: int diff --git a/plugins/hydra_nevergrad_sweeper/hydra_plugins/hydra_nevergrad_sweeper/_impl.py b/plugins/hydra_nevergrad_sweeper/hydra_plugins/hydra_nevergrad_sweeper/_impl.py index 52b79b7c1c5..dec1295f2b8 100644 --- a/plugins/hydra_nevergrad_sweeper/hydra_plugins/hydra_nevergrad_sweeper/_impl.py +++ b/plugins/hydra_nevergrad_sweeper/hydra_plugins/hydra_nevergrad_sweeper/_impl.py @@ -12,7 +12,6 @@ ) import nevergrad as ng -from hydra.core.config_loader import ConfigLoader from hydra.core.override_parser.overrides_parser import OverridesParser from hydra.core.override_parser.types import ( ChoiceSweep, @@ -23,7 +22,7 @@ from hydra.core.plugins import Plugins from hydra.plugins.launcher import Launcher from hydra.plugins.sweeper import Sweeper -from hydra.types import TaskFunction +from hydra.types import HydraContext, TaskFunction from omegaconf import DictConfig, ListConfig, OmegaConf from .config import OptimConf, ScalarConfigSpec @@ -90,6 +89,7 @@ def __init__( self.opt_config = optim self.config: Optional[DictConfig] = None self.launcher: Optional[Launcher] = None + self.hydra_context: Optional[HydraContext] = None self.job_results = None self.parametrization: Dict[str, Any] = {} if parametrization is not None: @@ -102,15 +102,16 @@ def __init__( def setup( self, - config: DictConfig, - config_loader: ConfigLoader, + *, + hydra_context: HydraContext, task_function: TaskFunction, + config: DictConfig, ) -> None: self.job_idx = 0 self.config = config - self.config_loader = config_loader + self.hydra_context = hydra_context self.launcher = Plugins.instance().instantiate_launcher( - config=config, config_loader=config_loader, task_function=task_function + hydra_context=hydra_context, task_function=task_function, config=config ) def sweep(self, arguments: List[str]) -> None: diff --git a/plugins/hydra_nevergrad_sweeper/hydra_plugins/hydra_nevergrad_sweeper/nevergrad_sweeper.py b/plugins/hydra_nevergrad_sweeper/hydra_plugins/hydra_nevergrad_sweeper/nevergrad_sweeper.py index 3567739e067..be0d4ecf8f9 100644 --- a/plugins/hydra_nevergrad_sweeper/hydra_plugins/hydra_nevergrad_sweeper/nevergrad_sweeper.py +++ b/plugins/hydra_nevergrad_sweeper/hydra_plugins/hydra_nevergrad_sweeper/nevergrad_sweeper.py @@ -2,8 +2,8 @@ from typing import List, Optional from hydra import TaskFunction -from hydra.core.config_loader import ConfigLoader from hydra.plugins.sweeper import Sweeper +from hydra.types import HydraContext from omegaconf import DictConfig from .config import OptimConf @@ -19,11 +19,14 @@ def __init__(self, optim: OptimConf, parametrization: Optional[DictConfig]): def setup( self, - config: DictConfig, - config_loader: ConfigLoader, + *, + hydra_context: HydraContext, task_function: TaskFunction, + config: DictConfig, ) -> None: - return self.sweeper.setup(config, config_loader, task_function) + return self.sweeper.setup( + hydra_context=hydra_context, task_function=task_function, config=config + ) def sweep(self, arguments: List[str]) -> None: return self.sweeper.sweep(arguments) diff --git a/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/_impl.py b/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/_impl.py index 33e63e6d543..37667bfd5cd 100644 --- a/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/_impl.py +++ b/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/_impl.py @@ -4,7 +4,6 @@ from typing import Any, Dict, List, MutableMapping, MutableSequence, Optional import optuna -from hydra.core.config_loader import ConfigLoader from hydra.core.override_parser.overrides_parser import OverridesParser from hydra.core.override_parser.types import ( ChoiceSweep, @@ -15,7 +14,7 @@ ) from hydra.core.plugins import Plugins from hydra.plugins.sweeper import Sweeper -from hydra.types import TaskFunction +from hydra.types import HydraContext, TaskFunction from omegaconf import DictConfig, OmegaConf from optuna.distributions import ( BaseDistribution, @@ -135,21 +134,23 @@ def __init__( def setup( self, - config: DictConfig, - config_loader: ConfigLoader, + *, + hydra_context: HydraContext, task_function: TaskFunction, + config: DictConfig, ) -> None: self.job_idx = 0 self.config = config - self.config_loader = config_loader + self.hydra_context = hydra_context self.launcher = Plugins.instance().instantiate_launcher( - config=config, config_loader=config_loader, task_function=task_function + config=config, hydra_context=hydra_context, task_function=task_function ) self.sweep_dir = config.hydra.sweep.dir def sweep(self, arguments: List[str]) -> None: assert self.config is not None assert self.launcher is not None + assert self.hydra_context is not None assert self.job_idx is not None parser = OverridesParser.create() diff --git a/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/optuna_sweeper.py b/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/optuna_sweeper.py index e7e86a1ca18..a9b0072310d 100644 --- a/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/optuna_sweeper.py +++ b/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/optuna_sweeper.py @@ -1,9 +1,8 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved from typing import Any, List, Optional -from hydra.core.config_loader import ConfigLoader from hydra.plugins.sweeper import Sweeper -from hydra.types import TaskFunction +from hydra.types import HydraContext, TaskFunction from omegaconf import DictConfig from .config import SamplerConfig @@ -30,11 +29,14 @@ def __init__( def setup( self, - config: DictConfig, - config_loader: ConfigLoader, + *, + hydra_context: HydraContext, task_function: TaskFunction, + config: DictConfig, ) -> None: - self.sweeper.setup(config, config_loader, task_function) + self.sweeper.setup( + hydra_context=hydra_context, task_function=task_function, config=config + ) def sweep(self, arguments: List[str]) -> None: return self.sweeper.sweep(arguments) diff --git a/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/_core.py b/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/_core.py index ece74e8cba4..a2220df279e 100644 --- a/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/_core.py +++ b/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/_core.py @@ -24,7 +24,7 @@ def launch( ) -> Sequence[JobReturn]: setup_globals() assert launcher.config is not None - assert launcher.config_loader is not None + assert launcher.hydra_context is not None assert launcher.task_function is not None configure_log(launcher.config.hydra.hydra_logging, launcher.config.hydra.verbose) @@ -42,7 +42,7 @@ def launch( idx = initial_job_idx + idx ostr = " ".join(filter_overrides(overrides)) log.info(f"\t#{idx} : {ostr}") - sweep_config = launcher.config_loader.load_sweep_config( + sweep_config = launcher.hydra_context.config_loader.load_sweep_config( launcher.config, list(overrides) ) with open_dict(sweep_config): @@ -52,6 +52,7 @@ def launch( sweep_config.hydra.job.id = f"job_id_for_{idx}" sweep_config.hydra.job.num = idx ray_obj = launch_job_on_ray( + launcher.hydra_context, launcher.ray_cfg.remote, sweep_config, launcher.task_function, diff --git a/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/_core_aws.py b/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/_core_aws.py index fcd42b3f73e..cd5c1f3fe0d 100644 --- a/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/_core_aws.py +++ b/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/_core_aws.py @@ -54,7 +54,7 @@ def launch( ) -> Sequence[JobReturn]: setup_globals() assert launcher.config is not None - assert launcher.config_loader is not None + assert launcher.hydra_context is not None assert launcher.task_function is not None setup_commands = launcher.env_setup.commands @@ -80,7 +80,7 @@ def launch( idx = initial_job_idx + idx ostr = " ".join(filter_overrides(overrides)) log.info(f"\t#{idx} : {ostr}") - sweep_config = launcher.config_loader.load_sweep_config( + sweep_config = launcher.hydra_context.config_loader.load_sweep_config( launcher.config, list(overrides) ) with open_dict(sweep_config): @@ -91,6 +91,7 @@ def launch( _pickle_jobs( tmp_dir=local_tmp_dir, + hydra_context=launcher.hydra_context, sweep_configs=sweep_configs, # type: ignore task_function=launcher.task_function, singleton_state=Singleton.get_state(), diff --git a/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/_launcher_util.py b/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/_launcher_util.py index 85d0bb1fca3..57033aae213 100644 --- a/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/_launcher_util.py +++ b/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/_launcher_util.py @@ -9,7 +9,7 @@ from hydra.core.hydra_config import HydraConfig from hydra.core.singleton import Singleton from hydra.core.utils import JobReturn, run_job, setup_globals -from hydra.types import TaskFunction +from hydra.types import HydraContext, TaskFunction from omegaconf import DictConfig, OmegaConf log = logging.getLogger(__name__) @@ -30,6 +30,7 @@ def start_ray(init_cfg: DictConfig) -> None: def _run_job( + hydra_context: HydraContext, sweep_config: DictConfig, task_function: TaskFunction, singleton_state: Dict[Any, Any], @@ -38,14 +39,16 @@ def _run_job( Singleton.set_state(singleton_state) HydraConfig.instance().set_config(sweep_config) return run_job( - config=sweep_config, + hydra_context=hydra_context, task_function=task_function, + config=sweep_config, job_dir_key="hydra.sweep.dir", job_subdir_key="hydra.sweep.subdir", ) def launch_job_on_ray( + hydra_context: HydraContext, ray_remote: DictConfig, sweep_config: DictConfig, task_function: TaskFunction, @@ -57,6 +60,7 @@ def launch_job_on_ray( run_job_ray = ray.remote(_run_job) ret = run_job_ray.remote( + hydra_context=hydra_context, sweep_config=sweep_config, task_function=task_function, singleton_state=singleton_state, diff --git a/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/_remote_invoke.py b/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/_remote_invoke.py index 725d04fb959..afe046cbf47 100644 --- a/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/_remote_invoke.py +++ b/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/_remote_invoke.py @@ -30,6 +30,7 @@ def launch_jobs(temp_dir: str) -> None: runs = [] with open(os.path.join(temp_dir, JOB_SPEC_PICKLE), "rb") as f: job_spec = pickle.load(f) # nosec + hydra_context = job_spec["hydra_context"] singleton_state = job_spec["singleton_state"] sweep_configs = job_spec["sweep_configs"] task_function = job_spec["task_function"] @@ -55,7 +56,7 @@ def launch_jobs(temp_dir: str) -> None: start_ray(ray_init) ray_obj = launch_job_on_ray( - ray_remote, sweep_config, task_function, singleton_state + hydra_context, ray_remote, sweep_config, task_function, singleton_state ) runs.append(ray_obj) diff --git a/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/ray_aws_launcher.py b/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/ray_aws_launcher.py index a891b9c135c..1a37e175dab 100644 --- a/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/ray_aws_launcher.py +++ b/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/ray_aws_launcher.py @@ -2,10 +2,9 @@ import logging from typing import Optional, Sequence -from hydra.core.config_loader import ConfigLoader from hydra.core.utils import JobReturn from hydra.plugins.launcher import Launcher -from hydra.types import TaskFunction +from hydra.types import HydraContext, TaskFunction from omegaconf import DictConfig from hydra_plugins.hydra_ray_launcher._config import ( # type: ignore @@ -30,19 +29,20 @@ def __init__( self.sync_up = sync_up self.sync_down = sync_down self.config: Optional[DictConfig] = None - self.config_loader: Optional[ConfigLoader] = None + self.hydra_context: Optional[HydraContext] = None self.task_function: Optional[TaskFunction] = None self.ray_yaml_path: Optional[str] = None self.env_setup = env_setup def setup( self, - config: DictConfig, - config_loader: ConfigLoader, + *, + hydra_context: HydraContext, task_function: TaskFunction, + config: DictConfig, ) -> None: self.config = config - self.config_loader = config_loader + self.hydra_context = hydra_context self.task_function = task_function def launch( diff --git a/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/ray_launcher.py b/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/ray_launcher.py index 6c045c0dd4b..8dae23acee9 100644 --- a/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/ray_launcher.py +++ b/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/ray_launcher.py @@ -1,28 +1,28 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved from typing import Optional, Sequence -from hydra.core.config_loader import ConfigLoader from hydra.core.utils import JobReturn from hydra.plugins.launcher import Launcher -from hydra.types import TaskFunction +from hydra.types import HydraContext, TaskFunction from omegaconf import DictConfig class RayLauncher(Launcher): def __init__(self, ray: DictConfig) -> None: self.ray_cfg = ray - self.config: Optional[DictConfig] = None - self.config_loader: Optional[ConfigLoader] = None + self.hydra_context: Optional[HydraContext] = None self.task_function: Optional[TaskFunction] = None + self.config: Optional[DictConfig] = None def setup( self, - config: DictConfig, - config_loader: ConfigLoader, + *, + hydra_context: HydraContext, task_function: TaskFunction, + config: DictConfig, ) -> None: self.config = config - self.config_loader = config_loader + self.hydra_context = hydra_context self.task_function = task_function def launch( diff --git a/plugins/hydra_rq_launcher/hydra_plugins/hydra_rq_launcher/_core.py b/plugins/hydra_rq_launcher/hydra_plugins/hydra_rq_launcher/_core.py index 977ba60b996..fe54ac3bf04 100644 --- a/plugins/hydra_rq_launcher/hydra_plugins/hydra_rq_launcher/_core.py +++ b/plugins/hydra_rq_launcher/hydra_plugins/hydra_rq_launcher/_core.py @@ -16,7 +16,7 @@ run_job, setup_globals, ) -from hydra.types import TaskFunction +from hydra.types import HydraContext, TaskFunction from omegaconf import DictConfig, OmegaConf, open_dict from redis import Redis from rq import Queue # type: ignore @@ -27,6 +27,7 @@ def execute_job( + hydra_context: HydraContext, sweep_config: DictConfig, task_function: TaskFunction, singleton_state: Dict[Any, Any], @@ -37,8 +38,9 @@ def execute_job( HydraConfig.instance().set_config(sweep_config) ret = run_job( - config=sweep_config, + hydra_context=hydra_context, task_function=task_function, + config=sweep_config, job_dir_key="hydra.sweep.dir", job_subdir_key="hydra.sweep.subdir", ) @@ -56,8 +58,8 @@ def launch( """ setup_globals() assert launcher.config is not None - assert launcher.config_loader is not None assert launcher.task_function is not None + assert launcher.hydra_context is not None configure_log(launcher.config.hydra.hydra_logging, launcher.config.hydra.verbose) sweep_dir = Path(str(launcher.config.hydra.sweep.dir)) @@ -113,7 +115,7 @@ def launch( if enqueue_keywords["description"] is None: enqueue_keywords["description"] = description - sweep_config = launcher.config_loader.load_sweep_config( + sweep_config = launcher.hydra_context.config_loader.load_sweep_config( launcher.config, list(overrides) ) with open_dict(sweep_config): @@ -122,6 +124,7 @@ def launch( job = queue.enqueue( execute_job, + hydra_context=launcher.hydra_context, sweep_config=sweep_config, task_function=launcher.task_function, singleton_state=singleton_state, diff --git a/plugins/hydra_rq_launcher/hydra_plugins/hydra_rq_launcher/rq_launcher.py b/plugins/hydra_rq_launcher/hydra_plugins/hydra_rq_launcher/rq_launcher.py index 37c95605896..04bd9172710 100644 --- a/plugins/hydra_rq_launcher/hydra_plugins/hydra_rq_launcher/rq_launcher.py +++ b/plugins/hydra_rq_launcher/hydra_plugins/hydra_rq_launcher/rq_launcher.py @@ -2,10 +2,9 @@ import logging from typing import Any, Optional, Sequence -from hydra.core.config_loader import ConfigLoader from hydra.core.utils import JobReturn from hydra.plugins.launcher import Launcher -from hydra.types import TaskFunction +from hydra.types import HydraContext, TaskFunction from omegaconf import DictConfig, OmegaConf from .config import RQLauncherConf @@ -21,20 +20,21 @@ def __init__(self, **params: Any) -> None: https://python-rq.org """ self.config: Optional[DictConfig] = None - self.config_loader: Optional[ConfigLoader] = None self.task_function: Optional[TaskFunction] = None + self.hydra_context: Optional[HydraContext] = None self.rq = OmegaConf.structured(RQLauncherConf(**params)) def setup( self, - config: DictConfig, - config_loader: ConfigLoader, + *, + hydra_context: HydraContext, task_function: TaskFunction, + config: DictConfig, ) -> None: self.config = config - self.config_loader = config_loader self.task_function = task_function + self.hydra_context = hydra_context def launch( self, job_overrides: Sequence[Sequence[str]], initial_job_idx: int diff --git a/plugins/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/submitit_launcher.py b/plugins/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/submitit_launcher.py index 8e4618dd451..8aa8f6f60f3 100644 --- a/plugins/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/submitit_launcher.py +++ b/plugins/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/submitit_launcher.py @@ -4,11 +4,10 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Sequence -from hydra import TaskFunction -from hydra.core.config_loader import ConfigLoader from hydra.core.singleton import Singleton from hydra.core.utils import JobReturn, filter_overrides, run_job, setup_globals from hydra.plugins.launcher import Launcher +from hydra.types import HydraContext, TaskFunction from omegaconf import DictConfig, OmegaConf, open_dict from .config import BaseQueueConf @@ -28,18 +27,19 @@ def __init__(self, **params: Any) -> None: self.params[k] = v self.config: Optional[DictConfig] = None - self.config_loader: Optional[ConfigLoader] = None self.task_function: Optional[TaskFunction] = None self.sweep_configs: Optional[TaskFunction] = None + self.hydra_context: Optional[HydraContext] = None def setup( self, - config: DictConfig, - config_loader: ConfigLoader, + *, + hydra_context: HydraContext, task_function: TaskFunction, + config: DictConfig, ) -> None: self.config = config - self.config_loader = config_loader + self.hydra_context = hydra_context self.task_function = task_function def __call__( @@ -53,13 +53,13 @@ def __call__( # lazy import to ensure plugin discovery remains fast import submitit - assert self.config_loader is not None + assert self.hydra_context is not None assert self.config is not None assert self.task_function is not None Singleton.set_state(singleton_state) setup_globals() - sweep_config = self.config_loader.load_sweep_config( + sweep_config = self.hydra_context.config_loader.load_sweep_config( self.config, sweep_overrides ) @@ -69,8 +69,9 @@ def __call__( sweep_config.hydra.job.num = job_num return run_job( - config=sweep_config, + hydra_context=self.hydra_context, task_function=self.task_function, + config=sweep_config, job_dir_key=job_dir_key, job_subdir_key="hydra.sweep.subdir", ) diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 6d045c0fff6..e64b503690a 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -23,7 +23,6 @@ """\ [HYDRA] Init custom_callback [HYDRA] custom_callback on_run_start - [HYDRA] Init custom_callback [JOB] custom_callback on_job_start [JOB] foo: bar @@ -42,7 +41,6 @@ [HYDRA] custom_callback on_multirun_start [HYDRA] Launching 1 jobs locally [HYDRA] \t#0 : foo=bar - [HYDRA] Init custom_callback [JOB] custom_callback on_job_start [JOB] foo: bar @@ -61,8 +59,6 @@ [HYDRA] Init callback_2 [HYDRA] callback_1 on_run_start [HYDRA] callback_2 on_run_start - [HYDRA] Init callback_1 - [HYDRA] Init callback_2 [JOB] callback_1 on_job_start [JOB] callback_2 on_job_start [JOB] {}