Skip to content

Commit

Permalink
Add hydra context
Browse files Browse the repository at this point in the history
  • Loading branch information
jieru-hu committed Apr 28, 2021
1 parent 40efb3c commit 265feaf
Show file tree
Hide file tree
Showing 27 changed files with 195 additions and 133 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pathlib import Path
from typing import Optional, Sequence

from hydra.core.config_loader import ConfigLoader
from hydra.core.utils import HydraContext
from hydra.core.config_store import ConfigStore
from hydra.core.singleton import Singleton
from hydra.core.utils import (
Expand Down Expand Up @@ -48,21 +48,22 @@ class LauncherConfig:
class ExampleLauncher(Launcher):
def __init__(self, foo: str, bar: str) -> None:
self.config: Optional[DictConfig] = None
self.config_loader: Optional[ConfigLoader] = None
self.task_function: Optional[TaskFunction] = None
self.hydra_context: Optional[HydraContext] = None

# foo and var are coming from the the plugin's configuration
self.foo = foo
self.bar = bar

def setup(
self,
config: DictConfig,
config_loader: ConfigLoader,
*,
hydra_context: HydraContext,
task_function: TaskFunction,
config: DictConfig,
) -> None:
self.config = config
self.config_loader = config_loader
self.hydra_context = hydra_context
self.task_function = task_function

def launch(
Expand All @@ -75,7 +76,7 @@ def launch(
"""
setup_globals()
assert self.config is not None
assert self.config_loader is not None
assert self.hydra_context is not None
assert self.task_function is not None

configure_log(self.config.hydra.hydra_logging, self.config.hydra.verbose)
Expand All @@ -91,7 +92,7 @@ def launch(
idx = initial_job_idx + idx
lst = " ".join(filter_overrides(overrides))
log.info(f"\t#{idx} : {lst}")
sweep_config = self.config_loader.load_sweep_config(
sweep_config = self.hydra_context.config_loader.load_sweep_config(
self.config, list(overrides)
)
with open_dict(sweep_config):
Expand All @@ -113,8 +114,9 @@ def launch(
Singleton.set_state(state)

ret = run_job(
config=sweep_config,
hydra_context=self.hydra_context,
task_function=self.task_function,
config=sweep_config,
job_dir_key="hydra.sweep.dir",
job_subdir_key="hydra.sweep.subdir",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pathlib import Path
from typing import Any, Iterable, List, Optional, Sequence

from hydra.core.config_loader import ConfigLoader
from hydra.core.utils import HydraContext
from hydra.core.config_store import ConfigStore
from hydra.core.override_parser.overrides_parser import OverridesParser
from hydra.core.plugins import Plugins
Expand Down Expand Up @@ -45,21 +45,23 @@ def __init__(self, max_batch_size: int, foo: str, bar: str):
self.max_batch_size = max_batch_size
self.config: Optional[DictConfig] = None
self.launcher: Optional[Launcher] = None
self.hydra_context: Optional[HydraContext] = None
self.job_results = None
self.foo = foo
self.bar = bar

def setup(
self,
config: DictConfig,
config_loader: ConfigLoader,
*,
hydra_context: HydraContext,
task_function: TaskFunction,
config: DictConfig,
) -> None:
self.config = config
self.config_loader = config_loader
self.launcher = Plugins.instance().instantiate_launcher(
config=config, config_loader=config_loader, task_function=task_function
hydra_context=hydra_context, task_function=task_function, config=config
)
self.hydra_context = hydra_context

def sweep(self, arguments: List[str]) -> Any:
assert self.config is not None
Expand Down
18 changes: 10 additions & 8 deletions hydra/_internal/core_plugins/basic_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

from omegaconf import DictConfig, open_dict

from hydra.core.config_loader import ConfigLoader
from hydra.core.config_store import ConfigStore
from hydra.core.utils import (
HydraContext,
JobReturn,
configure_log,
filter_overrides,
Expand All @@ -35,26 +35,27 @@ class BasicLauncher(Launcher):
def __init__(self) -> None:
super().__init__()
self.config: Optional[DictConfig] = None
self.config_loader: Optional[ConfigLoader] = None
self.task_function: Optional[TaskFunction] = None
self.hydra_context: Optional[HydraContext] = None

def setup(
self,
config: DictConfig,
config_loader: ConfigLoader,
*,
hydra_context: HydraContext,
task_function: TaskFunction,
config: DictConfig,
) -> None:
self.config = config
self.config_loader = config_loader
self.hydra_context = hydra_context
self.task_function = task_function

def launch(
self, job_overrides: Sequence[Sequence[str]], initial_job_idx: int
) -> Sequence[JobReturn]:
setup_globals()
assert self.hydra_context is not None
assert self.config is not None
assert self.task_function is not None
assert self.config_loader is not None

configure_log(self.config.hydra.hydra_logging, self.config.hydra.verbose)
sweep_dir = self.config.hydra.sweep.dir
Expand All @@ -65,15 +66,16 @@ def launch(
idx = initial_job_idx + idx
lst = " ".join(filter_overrides(overrides))
log.info(f"\t#{idx} : {lst}")
sweep_config = self.config_loader.load_sweep_config(
sweep_config = self.hydra_context.config_loader.load_sweep_config(
self.config, list(overrides)
)
with open_dict(sweep_config):
sweep_config.hydra.job.id = idx
sweep_config.hydra.job.num = idx
ret = run_job(
config=sweep_config,
hydra_context=self.hydra_context,
task_function=self.task_function,
config=sweep_config,
job_dir_key="hydra.sweep.dir",
job_subdir_key="hydra.sweep.subdir",
)
Expand Down
19 changes: 11 additions & 8 deletions hydra/_internal/core_plugins/basic_sweeper.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@

from omegaconf import DictConfig, OmegaConf

from hydra.core.config_loader import ConfigLoader
from hydra.core.config_store import ConfigStore
from hydra.core.override_parser.overrides_parser import OverridesParser
from hydra.core.override_parser.types import Override
from hydra.core.utils import JobReturn
from hydra.core.utils import HydraContext, JobReturn
from hydra.errors import HydraException
from hydra.plugins.launcher import Launcher
from hydra.plugins.sweeper import Sweeper
Expand Down Expand Up @@ -64,23 +63,26 @@ def __init__(self, max_batch_size: Optional[int]) -> None:
self.batch_index = 0
self.max_batch_size = max_batch_size

self.config_loader: Optional[ConfigLoader] = None
self.hydra_context: Optional[HydraContext] = None
self.config: Optional[DictConfig] = None
self.launcher: Optional[Launcher] = None

def setup(
self,
config: DictConfig,
config_loader: ConfigLoader,
*,
hydra_context: HydraContext,
task_function: TaskFunction,
config: DictConfig,
) -> None:
from hydra.core.plugins import Plugins

self.config_loader = config_loader
self.hydra_context = hydra_context
self.config = config

self.launcher = Plugins.instance().instantiate_launcher(
config=config, config_loader=config_loader, task_function=task_function
hydra_context=hydra_context,
task_function=task_function,
config=config,
)

@staticmethod
Expand Down Expand Up @@ -128,8 +130,9 @@ def split_arguments(
def sweep(self, arguments: List[str]) -> Any:
assert self.config is not None
assert self.launcher is not None
assert self.hydra_context is not None

parser = OverridesParser.create(config_loader=self.config_loader)
parser = OverridesParser.create(config_loader=self.hydra_context.config_loader)
overrides = parser.parse_overrides(arguments)

self.overrides = self.split_arguments(overrides, self.max_batch_size)
Expand Down
12 changes: 10 additions & 2 deletions hydra/_internal/hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from hydra.core.config_search_path import ConfigSearchPath
from hydra.core.plugins import Plugins
from hydra.core.utils import (
HydraContext,
JobReturn,
JobRuntime,
configure_log,
Expand Down Expand Up @@ -95,8 +96,11 @@ def run(
callbacks.on_run_start(config=cfg, config_name=config_name)

ret = run_job(
config=cfg,
hydra_context=HydraContext(
config_loader=self.config_loader, callbacks=callbacks
),
task_function=task_function,
config=cfg,
job_dir_key="hydra.run.dir",
job_subdir_key=None,
configure_logging=with_log_configuration,
Expand Down Expand Up @@ -125,7 +129,11 @@ def multirun(
callbacks.on_multirun_start(config=cfg, config_name=config_name)

sweeper = Plugins.instance().instantiate_sweeper(
config=cfg, config_loader=self.config_loader, task_function=task_function
config=cfg,
hydra_context=HydraContext(
config_loader=self.config_loader, callbacks=callbacks
),
task_function=task_function,
)
task_overrides = OmegaConf.to_container(cfg.hydra.overrides.task, resolve=False)
assert isinstance(task_overrides, list)
Expand Down
16 changes: 9 additions & 7 deletions hydra/core/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from omegaconf import DictConfig

from hydra._internal.sources_registry import SourcesRegistry
from hydra.core.config_loader import ConfigLoader
from hydra.core.singleton import Singleton
from hydra.core.utils import HydraContext
from hydra.plugins.completion_plugin import CompletionPlugin
from hydra.plugins.config_source import ConfigSource
from hydra.plugins.launcher import Launcher
Expand Down Expand Up @@ -106,33 +106,35 @@ def is_in_toplevel_plugins_module(clazz: str) -> bool:

def instantiate_sweeper(
self,
config: DictConfig,
config_loader: ConfigLoader,
*,
hydra_context: HydraContext,
task_function: TaskFunction,
config: DictConfig,
) -> Sweeper:
Plugins.check_usage(self)
if config.hydra.sweeper is None:
raise RuntimeError("Hydra sweeper is not configured")
sweeper = self._instantiate(config.hydra.sweeper)
assert isinstance(sweeper, Sweeper)
sweeper.setup(
config=config, config_loader=config_loader, task_function=task_function
config=config, hydra_context=hydra_context, task_function=task_function
)
return sweeper

def instantiate_launcher(
self,
config: DictConfig,
config_loader: ConfigLoader,
*,
hydra_context: HydraContext,
task_function: TaskFunction,
config: DictConfig,
) -> Launcher:
Plugins.check_usage(self)
if config.hydra.launcher is None:
raise RuntimeError("Hydra launcher is not configured")
launcher = self._instantiate(config.hydra.launcher)
assert isinstance(launcher, Launcher)
launcher.setup(
config=config, config_loader=config_loader, task_function=task_function
config=config, hydra_context=hydra_context, task_function=task_function
)
return launcher

Expand Down
19 changes: 15 additions & 4 deletions hydra/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,28 @@
from os.path import splitext
from pathlib import Path
from textwrap import dedent
from typing import Any, Dict, Optional, Sequence, Union, cast
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union, cast

from omegaconf import DictConfig, OmegaConf, open_dict, read_write

from hydra.core.config_loader import ConfigLoader
from hydra.core.hydra_config import HydraConfig
from hydra.core.singleton import Singleton
from hydra.errors import HydraJobException
from hydra.types import TaskFunction

if TYPE_CHECKING:
from hydra._internal.callbacks import Callbacks

log = logging.getLogger(__name__)


@dataclass
class HydraContext:
config_loader: ConfigLoader
callbacks: "Callbacks"


def simple_stdout_log_config(level: int = logging.INFO) -> None:
root = logging.getLogger()
root.setLevel(level)
Expand Down Expand Up @@ -83,15 +93,16 @@ def filter_overrides(overrides: Sequence[str]) -> Sequence[str]:


def run_job(
config: DictConfig,
*,
hydra_context: HydraContext,
task_function: TaskFunction,
config: DictConfig,
job_dir_key: str,
job_subdir_key: Optional[str],
configure_logging: bool = True,
) -> "JobReturn":
from hydra._internal.callbacks import Callbacks

callbacks = Callbacks(config)
callbacks = hydra_context.callbacks

old_cwd = os.getcwd()
orig_hydra_cfg = HydraConfig.instance().cfg
Expand Down
8 changes: 4 additions & 4 deletions hydra/plugins/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@

from omegaconf import DictConfig

from hydra.core.config_loader import ConfigLoader
from hydra.core.utils import JobReturn
from hydra.core.utils import JobReturn, HydraContext
from hydra.types import TaskFunction

from .plugin import Plugin
Expand All @@ -18,9 +17,10 @@ class Launcher(Plugin):
@abstractmethod
def setup(
self,
config: DictConfig,
config_loader: ConfigLoader,
*,
hydra_context: HydraContext,
task_function: TaskFunction,
config: DictConfig,
) -> None:
"""
Sets this launcher instance up.
Expand Down
Loading

0 comments on commit 265feaf

Please sign in to comment.