Skip to content

Commit

Permalink
Add HydraContext (#1581)
Browse files Browse the repository at this point in the history
  • Loading branch information
jieru-hu authored May 4, 2021
1 parent 901fb39 commit c6397d2
Show file tree
Hide file tree
Showing 29 changed files with 193 additions and 144 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pathlib import Path
from typing import Optional, Sequence

from hydra.core.config_loader import ConfigLoader
from hydra.types import HydraContext
from hydra.core.config_store import ConfigStore
from hydra.core.singleton import Singleton
from hydra.core.utils import (
Expand Down Expand Up @@ -48,21 +48,22 @@ class LauncherConfig:
class ExampleLauncher(Launcher):
def __init__(self, foo: str, bar: str) -> None:
self.config: Optional[DictConfig] = None
self.config_loader: Optional[ConfigLoader] = None
self.task_function: Optional[TaskFunction] = None
self.hydra_context: Optional[HydraContext] = None

# foo and var are coming from the the plugin's configuration
self.foo = foo
self.bar = bar

def setup(
self,
config: DictConfig,
config_loader: ConfigLoader,
*,
hydra_context: HydraContext,
task_function: TaskFunction,
config: DictConfig,
) -> None:
self.config = config
self.config_loader = config_loader
self.hydra_context = hydra_context
self.task_function = task_function

def launch(
Expand All @@ -75,7 +76,7 @@ def launch(
"""
setup_globals()
assert self.config is not None
assert self.config_loader is not None
assert self.hydra_context is not None
assert self.task_function is not None

configure_log(self.config.hydra.hydra_logging, self.config.hydra.verbose)
Expand All @@ -91,7 +92,7 @@ def launch(
idx = initial_job_idx + idx
lst = " ".join(filter_overrides(overrides))
log.info(f"\t#{idx} : {lst}")
sweep_config = self.config_loader.load_sweep_config(
sweep_config = self.hydra_context.config_loader.load_sweep_config(
self.config, list(overrides)
)
with open_dict(sweep_config):
Expand All @@ -113,8 +114,9 @@ def launch(
Singleton.set_state(state)

ret = run_job(
config=sweep_config,
hydra_context=self.hydra_context,
task_function=self.task_function,
config=sweep_config,
job_dir_key="hydra.sweep.dir",
job_subdir_key="hydra.sweep.subdir",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pathlib import Path
from typing import Any, Iterable, List, Optional, Sequence

from hydra.core.config_loader import ConfigLoader
from hydra.types import HydraContext
from hydra.core.config_store import ConfigStore
from hydra.core.override_parser.overrides_parser import OverridesParser
from hydra.core.plugins import Plugins
Expand Down Expand Up @@ -45,21 +45,23 @@ def __init__(self, max_batch_size: int, foo: str, bar: str):
self.max_batch_size = max_batch_size
self.config: Optional[DictConfig] = None
self.launcher: Optional[Launcher] = None
self.hydra_context: Optional[HydraContext] = None
self.job_results = None
self.foo = foo
self.bar = bar

def setup(
self,
config: DictConfig,
config_loader: ConfigLoader,
*,
hydra_context: HydraContext,
task_function: TaskFunction,
config: DictConfig,
) -> None:
self.config = config
self.config_loader = config_loader
self.launcher = Plugins.instance().instantiate_launcher(
config=config, config_loader=config_loader, task_function=task_function
hydra_context=hydra_context, task_function=task_function, config=config
)
self.hydra_context = hydra_context

def sweep(self, arguments: List[str]) -> Any:
assert self.config is not None
Expand Down
19 changes: 10 additions & 9 deletions hydra/_internal/core_plugins/basic_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from omegaconf import DictConfig, open_dict

from hydra.core.config_loader import ConfigLoader
from hydra.core.config_store import ConfigStore
from hydra.core.utils import (
JobReturn,
Expand All @@ -16,7 +15,7 @@
setup_globals,
)
from hydra.plugins.launcher import Launcher
from hydra.types import TaskFunction
from hydra.types import HydraContext, TaskFunction

log = logging.getLogger(__name__)

Expand All @@ -35,26 +34,27 @@ class BasicLauncher(Launcher):
def __init__(self) -> None:
super().__init__()
self.config: Optional[DictConfig] = None
self.config_loader: Optional[ConfigLoader] = None
self.task_function: Optional[TaskFunction] = None
self.hydra_context: Optional[HydraContext] = None

def setup(
self,
config: DictConfig,
config_loader: ConfigLoader,
*,
hydra_context: HydraContext,
task_function: TaskFunction,
config: DictConfig,
) -> None:
self.config = config
self.config_loader = config_loader
self.hydra_context = hydra_context
self.task_function = task_function

def launch(
self, job_overrides: Sequence[Sequence[str]], initial_job_idx: int
) -> Sequence[JobReturn]:
setup_globals()
assert self.hydra_context is not None
assert self.config is not None
assert self.task_function is not None
assert self.config_loader is not None

configure_log(self.config.hydra.hydra_logging, self.config.hydra.verbose)
sweep_dir = self.config.hydra.sweep.dir
Expand All @@ -65,15 +65,16 @@ def launch(
idx = initial_job_idx + idx
lst = " ".join(filter_overrides(overrides))
log.info(f"\t#{idx} : {lst}")
sweep_config = self.config_loader.load_sweep_config(
sweep_config = self.hydra_context.config_loader.load_sweep_config(
self.config, list(overrides)
)
with open_dict(sweep_config):
sweep_config.hydra.job.id = idx
sweep_config.hydra.job.num = idx
ret = run_job(
config=sweep_config,
hydra_context=self.hydra_context,
task_function=self.task_function,
config=sweep_config,
job_dir_key="hydra.sweep.dir",
job_subdir_key="hydra.sweep.subdir",
)
Expand Down
19 changes: 11 additions & 8 deletions hydra/_internal/core_plugins/basic_sweeper.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,14 @@

from omegaconf import DictConfig, OmegaConf

from hydra.core.config_loader import ConfigLoader
from hydra.core.config_store import ConfigStore
from hydra.core.override_parser.overrides_parser import OverridesParser
from hydra.core.override_parser.types import Override
from hydra.core.utils import JobReturn
from hydra.errors import HydraException
from hydra.plugins.launcher import Launcher
from hydra.plugins.sweeper import Sweeper
from hydra.types import TaskFunction
from hydra.types import HydraContext, TaskFunction


@dataclass
Expand Down Expand Up @@ -64,23 +63,26 @@ def __init__(self, max_batch_size: Optional[int]) -> None:
self.batch_index = 0
self.max_batch_size = max_batch_size

self.config_loader: Optional[ConfigLoader] = None
self.hydra_context: Optional[HydraContext] = None
self.config: Optional[DictConfig] = None
self.launcher: Optional[Launcher] = None

def setup(
self,
config: DictConfig,
config_loader: ConfigLoader,
*,
hydra_context: HydraContext,
task_function: TaskFunction,
config: DictConfig,
) -> None:
from hydra.core.plugins import Plugins

self.config_loader = config_loader
self.hydra_context = hydra_context
self.config = config

self.launcher = Plugins.instance().instantiate_launcher(
config=config, config_loader=config_loader, task_function=task_function
hydra_context=hydra_context,
task_function=task_function,
config=config,
)

@staticmethod
Expand Down Expand Up @@ -128,8 +130,9 @@ def split_arguments(
def sweep(self, arguments: List[str]) -> Any:
assert self.config is not None
assert self.launcher is not None
assert self.hydra_context is not None

parser = OverridesParser.create(config_loader=self.config_loader)
parser = OverridesParser.create(config_loader=self.hydra_context.config_loader)
overrides = parser.parse_overrides(arguments)

self.overrides = self.split_arguments(overrides, self.max_batch_size)
Expand Down
13 changes: 10 additions & 3 deletions hydra/_internal/hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from hydra.plugins.launcher import Launcher
from hydra.plugins.search_path_plugin import SearchPathPlugin
from hydra.plugins.sweeper import Sweeper
from hydra.types import RunMode, TaskFunction
from hydra.types import HydraContext, RunMode, TaskFunction

from ..core.default_element import DefaultsTreeNode, InputDefault
from .callbacks import Callbacks
Expand Down Expand Up @@ -95,8 +95,11 @@ def run(
callbacks.on_run_start(config=cfg, config_name=config_name)

ret = run_job(
config=cfg,
hydra_context=HydraContext(
config_loader=self.config_loader, callbacks=callbacks
),
task_function=task_function,
config=cfg,
job_dir_key="hydra.run.dir",
job_subdir_key=None,
configure_logging=with_log_configuration,
Expand Down Expand Up @@ -125,7 +128,11 @@ def multirun(
callbacks.on_multirun_start(config=cfg, config_name=config_name)

sweeper = Plugins.instance().instantiate_sweeper(
config=cfg, config_loader=self.config_loader, task_function=task_function
config=cfg,
hydra_context=HydraContext(
config_loader=self.config_loader, callbacks=callbacks
),
task_function=task_function,
)
task_overrides = OmegaConf.to_container(cfg.hydra.overrides.task, resolve=False)
assert isinstance(task_overrides, list)
Expand Down
17 changes: 9 additions & 8 deletions hydra/core/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@
from omegaconf import DictConfig

from hydra._internal.sources_registry import SourcesRegistry
from hydra.core.config_loader import ConfigLoader
from hydra.core.singleton import Singleton
from hydra.plugins.completion_plugin import CompletionPlugin
from hydra.plugins.config_source import ConfigSource
from hydra.plugins.launcher import Launcher
from hydra.plugins.plugin import Plugin
from hydra.plugins.search_path_plugin import SearchPathPlugin
from hydra.plugins.sweeper import Sweeper
from hydra.types import TaskFunction
from hydra.types import HydraContext, TaskFunction
from hydra.utils import instantiate


Expand Down Expand Up @@ -106,33 +105,35 @@ def is_in_toplevel_plugins_module(clazz: str) -> bool:

def instantiate_sweeper(
self,
config: DictConfig,
config_loader: ConfigLoader,
*,
hydra_context: HydraContext,
task_function: TaskFunction,
config: DictConfig,
) -> Sweeper:
Plugins.check_usage(self)
if config.hydra.sweeper is None:
raise RuntimeError("Hydra sweeper is not configured")
sweeper = self._instantiate(config.hydra.sweeper)
assert isinstance(sweeper, Sweeper)
sweeper.setup(
config=config, config_loader=config_loader, task_function=task_function
config=config, hydra_context=hydra_context, task_function=task_function
)
return sweeper

def instantiate_launcher(
self,
config: DictConfig,
config_loader: ConfigLoader,
*,
hydra_context: HydraContext,
task_function: TaskFunction,
config: DictConfig,
) -> Launcher:
Plugins.check_usage(self)
if config.hydra.launcher is None:
raise RuntimeError("Hydra launcher is not configured")
launcher = self._instantiate(config.hydra.launcher)
assert isinstance(launcher, Launcher)
launcher.setup(
config=config, config_loader=config_loader, task_function=task_function
config=config, hydra_context=hydra_context, task_function=task_function
)
return launcher

Expand Down
9 changes: 5 additions & 4 deletions hydra/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from hydra.core.hydra_config import HydraConfig
from hydra.core.singleton import Singleton
from hydra.errors import HydraJobException
from hydra.types import TaskFunction
from hydra.types import HydraContext, TaskFunction

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -83,15 +83,16 @@ def filter_overrides(overrides: Sequence[str]) -> Sequence[str]:


def run_job(
config: DictConfig,
*,
hydra_context: HydraContext,
task_function: TaskFunction,
config: DictConfig,
job_dir_key: str,
job_subdir_key: Optional[str],
configure_logging: bool = True,
) -> "JobReturn":
from hydra._internal.callbacks import Callbacks

callbacks = Callbacks(config)
callbacks = hydra_context.callbacks

old_cwd = os.getcwd()
orig_hydra_cfg = HydraConfig.instance().cfg
Expand Down
9 changes: 5 additions & 4 deletions hydra/plugins/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

from omegaconf import DictConfig

from hydra.core.config_loader import ConfigLoader
from hydra.core.utils import JobReturn
from hydra.types import TaskFunction

from hydra.types import TaskFunction, HydraContext

from .plugin import Plugin

Expand All @@ -18,9 +18,10 @@ class Launcher(Plugin):
@abstractmethod
def setup(
self,
config: DictConfig,
config_loader: ConfigLoader,
*,
hydra_context: HydraContext,
task_function: TaskFunction,
config: DictConfig,
) -> None:
"""
Sets this launcher instance up.
Expand Down
Loading

0 comments on commit c6397d2

Please sign in to comment.