Skip to content

Commit

Permalink
add HydraContext
Browse files Browse the repository at this point in the history
  • Loading branch information
jieru-hu committed Apr 28, 2021
1 parent 2500514 commit 3100fc6
Show file tree
Hide file tree
Showing 28 changed files with 99 additions and 107 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import Optional, Sequence

from hydra.core.utils import HydraContext
from hydra.core.config_loader import ConfigLoader
from hydra.core.config_store import ConfigStore
from hydra.core.singleton import Singleton
from hydra.core.utils import (
Expand Down Expand Up @@ -49,8 +48,8 @@ class LauncherConfig:
class ExampleLauncher(Launcher):
def __init__(self, foo: str, bar: str) -> None:
self.config: Optional[DictConfig] = None
self.config_loader: Optional[ConfigLoader] = None
self.task_function: Optional[TaskFunction] = None
self.hydra_context: Optional[HydraContext] = None

# foo and var are coming from the the plugin's configuration
self.foo = foo
Expand All @@ -64,7 +63,7 @@ def setup(
config: DictConfig,
) -> None:
self.config = config
self.config_loader = hydra_context.config_loader
self.hydra_context = hydra_context
self.task_function = task_function

def launch(
Expand All @@ -77,7 +76,7 @@ def launch(
"""
setup_globals()
assert self.config is not None
assert self.config_loader is not None
assert self.hydra_context is not None
assert self.task_function is not None

configure_log(self.config.hydra.hydra_logging, self.config.hydra.verbose)
Expand All @@ -93,7 +92,7 @@ def launch(
idx = initial_job_idx + idx
lst = " ".join(filter_overrides(overrides))
log.info(f"\t#{idx} : {lst}")
sweep_config = self.config_loader.load_sweep_config(
sweep_config = self.hydra_context.config_loader.load_sweep_config(
self.config, list(overrides)
)
with open_dict(sweep_config):
Expand All @@ -115,8 +114,9 @@ def launch(
Singleton.set_state(state)

ret = run_job(
config=sweep_config,
hydra_context=self.hydra_context,
task_function=self.task_function,
config=sweep_config,
job_dir_key="hydra.sweep.dir",
job_subdir_key="hydra.sweep.subdir",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import Any, Iterable, List, Optional, Sequence

from hydra.core.utils import HydraContext
from hydra.core.config_loader import ConfigLoader
from hydra.core.config_store import ConfigStore
from hydra.core.override_parser.overrides_parser import OverridesParser
from hydra.core.plugins import Plugins
Expand Down Expand Up @@ -46,6 +45,7 @@ def __init__(self, max_batch_size: int, foo: str, bar: str):
self.max_batch_size = max_batch_size
self.config: Optional[DictConfig] = None
self.launcher: Optional[Launcher] = None
self.hydra_context: Optional[HydraContext] = None
self.job_results = None
self.foo = foo
self.bar = bar
Expand All @@ -58,10 +58,10 @@ def setup(
config: DictConfig,
) -> None:
self.config = config
self.config_loader = hydra_context.config_loader
self.launcher = Plugins.instance().instantiate_launcher(
hydra_context=hydra_context, task_function=task_function, config=config
)
self.hydra_context = hydra_context

def sweep(self, arguments: List[str]) -> Any:
assert self.config is not None
Expand Down
10 changes: 3 additions & 7 deletions hydra/_internal/core_plugins/basic_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@

from omegaconf import DictConfig, open_dict

from hydra.core.utils import HydraContext
from hydra.core.callbacks import Callbacks
from hydra.core.config_loader import ConfigLoader
from hydra.core.config_store import ConfigStore
from hydra.core.utils import (
HydraContext,
JobReturn,
configure_log,
filter_overrides,
Expand All @@ -37,7 +35,6 @@ class BasicLauncher(Launcher):
def __init__(self) -> None:
super().__init__()
self.config: Optional[DictConfig] = None
self.config_loader: Optional[ConfigLoader] = None
self.task_function: Optional[TaskFunction] = None
self.hydra_context: Optional[HydraContext] = None

Expand All @@ -50,16 +47,15 @@ def setup(
) -> None:
self.config = config
self.hydra_context = hydra_context
self.config_loader = hydra_context.config_loader
self.task_function = task_function

def launch(
self, job_overrides: Sequence[Sequence[str]], initial_job_idx: int
) -> Sequence[JobReturn]:
setup_globals()
assert self.hydra_context is not None
assert self.config is not None
assert self.task_function is not None
assert self.config_loader is not None

configure_log(self.config.hydra.hydra_logging, self.config.hydra.verbose)
sweep_dir = self.config.hydra.sweep.dir
Expand All @@ -70,7 +66,7 @@ def launch(
idx = initial_job_idx + idx
lst = " ".join(filter_overrides(overrides))
log.info(f"\t#{idx} : {lst}")
sweep_config = self.config_loader.load_sweep_config(
sweep_config = self.hydra_context.config_loader.load_sweep_config(
self.config, list(overrides)
)
with open_dict(sweep_config):
Expand Down
15 changes: 8 additions & 7 deletions hydra/_internal/core_plugins/basic_sweeper.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,10 @@

from omegaconf import DictConfig, OmegaConf

from hydra.core.utils import HydraContext
from hydra.core.config_loader import ConfigLoader
from hydra.core.config_store import ConfigStore
from hydra.core.override_parser.overrides_parser import OverridesParser
from hydra.core.override_parser.types import Override
from hydra.core.utils import JobReturn
from hydra.core.utils import HydraContext, JobReturn
from hydra.errors import HydraException
from hydra.plugins.launcher import Launcher
from hydra.plugins.sweeper import Sweeper
Expand Down Expand Up @@ -65,7 +63,7 @@ def __init__(self, max_batch_size: Optional[int]) -> None:
self.batch_index = 0
self.max_batch_size = max_batch_size

self.config_loader: Optional[ConfigLoader] = None
self.hydra_context: Optional[HydraContext] = None
self.config: Optional[DictConfig] = None
self.launcher: Optional[Launcher] = None

Expand All @@ -78,11 +76,13 @@ def setup(
) -> None:
from hydra.core.plugins import Plugins

self.config_loader = hydra_context.config_loader
self.hydra_context = hydra_context
self.config = config

self.launcher = Plugins.instance().instantiate_launcher(
config=config, config_loader=self.config_loader, task_function=task_function
hydra_context=hydra_context,
task_function=task_function,
config=config,
)

@staticmethod
Expand Down Expand Up @@ -130,8 +130,9 @@ def split_arguments(
def sweep(self, arguments: List[str]) -> Any:
assert self.config is not None
assert self.launcher is not None
assert self.hydra_context is not None

parser = OverridesParser.create(config_loader=self.config_loader)
parser = OverridesParser.create(config_loader=self.hydra_context.config_loader)
overrides = parser.parse_overrides(arguments)

self.overrides = self.split_arguments(overrides, self.max_batch_size)
Expand Down
10 changes: 6 additions & 4 deletions hydra/_internal/hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,22 @@
import sys
from argparse import ArgumentParser
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, DefaultDict, List, Optional, Sequence, Type, Union

from omegaconf import Container, DictConfig, OmegaConf, flag_override, open_dict

from hydra._internal.utils import get_column_widths, run_and_report, Callbacks
from hydra._internal.utils import Callbacks, get_column_widths, run_and_report
from hydra.core.config_loader import ConfigLoader
from hydra.core.config_search_path import ConfigSearchPath
from hydra.core.plugins import Plugins
from hydra.core.utils import (
HydraContext,
JobReturn,
JobRuntime,
configure_log,
run_job,
setup_globals,
simple_stdout_log_config, HydraContext,
simple_stdout_log_config,
)
from hydra.plugins.completion_plugin import CompletionPlugin
from hydra.plugins.config_source import ConfigSource
Expand Down Expand Up @@ -95,7 +95,9 @@ def run(
callbacks.on_run_start(config=cfg, config_name=config_name)

ret = run_job(
hydra_context=HydraContext(callbacks=callbacks),
hydra_context=HydraContext(
config_loader=self.config_loader, callbacks=callbacks
),
task_function=task_function,
config=cfg,
job_dir_key="hydra.run.dir",
Expand Down
2 changes: 1 addition & 1 deletion hydra/_internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,4 +648,4 @@ def on_job_end(
job_return=job_return,
reverse=True,
**kwargs,
)
)
3 changes: 1 addition & 2 deletions hydra/core/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@

from omegaconf import DictConfig

from hydra.core.utils import HydraContext
from hydra._internal.sources_registry import SourcesRegistry
from hydra.core.config_loader import ConfigLoader
from hydra.core.singleton import Singleton
from hydra.core.utils import HydraContext
from hydra.plugins.completion_plugin import CompletionPlugin
from hydra.plugins.config_source import ConfigSource
from hydra.plugins.launcher import Launcher
Expand Down
11 changes: 4 additions & 7 deletions hydra/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from os.path import splitext
from pathlib import Path
from textwrap import dedent
from typing import Any, Dict, Optional, Sequence, Union, cast, TYPE_CHECKING
import warnings
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union, cast

from omegaconf import DictConfig, OmegaConf, open_dict, read_write

Expand All @@ -21,6 +20,7 @@
from hydra.core.singleton import Singleton
from hydra.errors import HydraJobException
from hydra.types import TaskFunction

if TYPE_CHECKING:
from hydra._internal.utils import Callbacks

Expand All @@ -29,9 +29,8 @@

@dataclass
class HydraContext:
config_loader: Optional[ConfigLoader] = None
callbacks: Optional["Callbacks"] = None

config_loader: ConfigLoader
callbacks: "Callbacks"


def simple_stdout_log_config(level: int = logging.INFO) -> None:
Expand Down Expand Up @@ -253,8 +252,6 @@ def validate_config_path(config_path: Optional[str]) -> None:
raise ValueError(msg)




@contextmanager
def env_override(env: Dict[str, str]) -> Any:
"""Temporarily set environment variables inside the context manager and
Expand Down
2 changes: 0 additions & 2 deletions hydra/plugins/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@

from omegaconf import DictConfig

from hydra.core.config_loader import ConfigLoader
from hydra.core.utils import JobReturn, HydraContext
from hydra.types import TaskFunction

from .plugin import Plugin



class Launcher(Plugin):
@abstractmethod
def setup(
Expand Down
9 changes: 4 additions & 5 deletions hydra/plugins/sweeper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@
from abc import abstractmethod
from typing import Any, List, Sequence, Optional

from hydra.core.config_loader import ConfigLoader
from hydra.types import TaskFunction
from omegaconf import DictConfig
from .launcher import Launcher

from .plugin import Plugin
from hydra.core.utils import JobReturn, HydraContext
from hydra.core.utils import HydraContext


class Sweeper(Plugin):
Expand All @@ -21,7 +20,7 @@ class Sweeper(Plugin):
(where each job typically takes a different command line arguments)
"""

config_loader: Optional[ConfigLoader]
hydra_context: Optional[HydraContext]
config: Optional[DictConfig]
launcher: Optional[Launcher]

Expand Down Expand Up @@ -52,9 +51,9 @@ def validate_batch_is_legal(self, batch: Sequence[Sequence[str]]) -> None:
This repeat work the launcher will do, but as the launcher may be performing this in a different
process/machine it's important to do it here as well to detect failures early.
"""
assert self.config_loader is not None
assert self.hydra_context is not None
assert self.config is not None
for overrides in batch:
self.config_loader.load_sweep_config(
self.hydra_context.config_loader.load_sweep_config(
master_config=self.config, sweep_overrides=list(overrides)
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
from ax.core import types as ax_types # type: ignore
from ax.exceptions.core import SearchSpaceExhausted # type: ignore
from ax.service.ax_client import AxClient # type: ignore

from hydra.core.utils import HydraContext
from hydra.core.config_loader import ConfigLoader
from hydra.core.override_parser.overrides_parser import OverridesParser
from hydra.core.override_parser.types import IntervalSweep, Override, Transformer
from hydra.core.plugins import Plugins
from hydra.core.utils import HydraContext
from hydra.plugins.launcher import Launcher
from hydra.plugins.sweeper import Sweeper
from hydra.types import TaskFunction
Expand Down Expand Up @@ -109,9 +107,9 @@ class CoreAxSweeper(Sweeper):
"""Class to interface with the Ax Platform"""

def __init__(self, ax_config: AxConfig, max_batch_size: Optional[int]):
self.config_loader: Optional[ConfigLoader] = None
self.config: Optional[DictConfig] = None
self.launcher: Optional[Launcher] = None
self.hydra_context: Optional[HydraContext] = None

self.job_results = None
self.experiment: ExperimentConfig = ax_config.experiment
Expand All @@ -138,7 +136,7 @@ def setup(
config: DictConfig,
) -> None:
self.config = config
self.config_loader = hydra_context.config_loader
self.hydra_context = hydra_context
self.launcher = Plugins.instance().instantiate_launcher(
config=config, hydra_context=hydra_context, task_function=task_function
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import List, Optional

from hydra.core.utils import HydraContext
from hydra.core.config_loader import ConfigLoader
from hydra.plugins.sweeper import Sweeper
from hydra.types import TaskFunction
from omegaconf import DictConfig
Expand Down
Loading

0 comments on commit 3100fc6

Please sign in to comment.