Skip to content

Commit

Permalink
Issue warning on outdated API usage
Browse files Browse the repository at this point in the history
  • Loading branch information
jieru-hu committed May 3, 2021
1 parent 059c0b4 commit 4c73fab
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 24 deletions.
7 changes: 4 additions & 3 deletions hydra/_internal/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import warnings
from typing import TYPE_CHECKING, Any

from omegaconf import DictConfig
from omegaconf import DictConfig, OmegaConf

if TYPE_CHECKING:
from hydra.core.utils import JobReturn
Expand All @@ -13,8 +13,9 @@ def __init__(self, config: DictConfig) -> None:
self.callbacks = []
from hydra.utils import instantiate

for params in config.hydra.callbacks.values():
self.callbacks.append(instantiate(params))
if OmegaConf.select(config, "hydra.callbacks"):
for params in config.hydra.callbacks.values():
self.callbacks.append(instantiate(params))

def _notify(self, function_name: str, reverse: bool = False, **kwargs: Any) -> None:
callbacks = reversed(self.callbacks) if reverse else self.callbacks
Expand Down
18 changes: 12 additions & 6 deletions hydra/_internal/core_plugins/basic_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from omegaconf import DictConfig, open_dict

from hydra.core.config_loader import ConfigLoader
from hydra.core.config_store import ConfigStore
from hydra.core.utils import (
JobReturn,
Expand Down Expand Up @@ -36,23 +37,30 @@ def __init__(self) -> None:
self.config: Optional[DictConfig] = None
self.task_function: Optional[TaskFunction] = None
self.hydra_context: Optional[HydraContext] = None
self.config_loader: Optional[ConfigLoader] = None

def setup(
self,
*,
hydra_context: HydraContext,
task_function: TaskFunction,
config: DictConfig,
hydra_context: Optional[HydraContext] = None,
config_loader: Optional[ConfigLoader] = None,
) -> None:
self.config = config
self.hydra_context = hydra_context
self.config_loader = config_loader
self.task_function = task_function

def launch(
self, job_overrides: Sequence[Sequence[str]], initial_job_idx: int
) -> Sequence[JobReturn]:
setup_globals()
assert self.hydra_context is not None
config_loader = (
self.hydra_context.config_loader
if self.hydra_context
else self.config_loader
)
assert config_loader is not None
assert self.config is not None
assert self.task_function is not None

Expand All @@ -65,9 +73,7 @@ def launch(
idx = initial_job_idx + idx
lst = " ".join(filter_overrides(overrides))
log.info(f"\t#{idx} : {lst}")
sweep_config = self.hydra_context.config_loader.load_sweep_config(
self.config, list(overrides)
)
sweep_config = config_loader.load_sweep_config(self.config, list(overrides))
with open_dict(sweep_config):
sweep_config.hydra.job.id = idx
sweep_config.hydra.job.num = idx
Expand Down
67 changes: 59 additions & 8 deletions hydra/core/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
import warnings
from collections import defaultdict
from dataclasses import dataclass, field
from inspect import signature
from timeit import default_timer as timer
from typing import Any, Dict, List, Optional, Tuple, Type

from omegaconf import DictConfig

from hydra._internal.sources_registry import SourcesRegistry
from hydra.core.config_loader import ConfigLoader
from hydra.core.singleton import Singleton
from hydra.plugins.completion_plugin import CompletionPlugin
from hydra.plugins.config_source import ConfigSource
Expand Down Expand Up @@ -103,6 +105,46 @@ def is_in_toplevel_plugins_module(clazz: str) -> bool:
"hydra._internal.core_plugins."
)

@staticmethod
def _setup_plugin(
plugin: Any,
task_function: TaskFunction,
config: DictConfig,
config_loader: Optional[ConfigLoader] = None,
hydra_context: Optional[HydraContext] = None,
) -> Any:
assert isinstance(plugin, Sweeper) or isinstance(plugin, Launcher)
assert (
config_loader is not None or hydra_context is not None
), "config_loader and hydra_context cannot both be None"

param_keys = signature(plugin.setup).parameters.keys()
if "config_loader" in param_keys and "hydra_context" not in param_keys:
warnings.warn(
message=(
"\n"
"\tPlugin's setup() signature has changed in Hydra 1.1.\n"
"\tSupport for the old style will be removed in Hydra 1.2.\n"
),
category=UserWarning,
)
config_loader = (
config_loader
if config_loader is not None
else hydra_context.config_loader # type: ignore
)
plugin.setup(
config=config,
config_loader=config_loader,
task_function=task_function,
) # type: ignore
else:
assert hydra_context is not None
plugin.setup(
config=config, hydra_context=hydra_context, task_function=task_function
)
return plugin

def instantiate_sweeper(
self,
*,
Expand All @@ -114,27 +156,36 @@ def instantiate_sweeper(
if config.hydra.sweeper is None:
raise RuntimeError("Hydra sweeper is not configured")
sweeper = self._instantiate(config.hydra.sweeper)
assert isinstance(sweeper, Sweeper)
sweeper.setup(
config=config, hydra_context=hydra_context, task_function=task_function
sweeper = self._setup_plugin(
plugin=sweeper,
task_function=task_function,
config=config,
config_loader=None,
hydra_context=hydra_context,
)
assert isinstance(sweeper, Sweeper)
return sweeper

def instantiate_launcher(
self,
*,
hydra_context: HydraContext,
task_function: TaskFunction,
config: DictConfig,
config_loader: Optional[ConfigLoader] = None,
hydra_context: Optional[HydraContext] = None,
) -> Launcher:
Plugins.check_usage(self)
if config.hydra.launcher is None:
raise RuntimeError("Hydra launcher is not configured")

launcher = self._instantiate(config.hydra.launcher)
assert isinstance(launcher, Launcher)
launcher.setup(
config=config, hydra_context=hydra_context, task_function=task_function
launcher = self._setup_plugin(
plugin=launcher,
config=config,
task_function=task_function,
config_loader=config_loader,
hydra_context=hydra_context,
)
assert isinstance(launcher, Launcher)
return launcher

@staticmethod
Expand Down
29 changes: 24 additions & 5 deletions hydra/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,25 @@
import os
import re
import sys
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from os.path import splitext
from pathlib import Path
from textwrap import dedent
from typing import Any, Dict, Optional, Sequence, Union, cast
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union, cast

from omegaconf import DictConfig, OmegaConf, open_dict, read_write

from hydra.core.hydra_config import HydraConfig
from hydra.core.singleton import Singleton
from hydra.errors import HydraJobException

if TYPE_CHECKING:
from hydra._internal.callbacks import Callbacks

from hydra.types import HydraContext, TaskFunction

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -82,17 +87,31 @@ def filter_overrides(overrides: Sequence[str]) -> Sequence[str]:
return [x for x in overrides if not x.startswith("hydra.")]


def _get_callbacks_for_run_job(hydra_context: Optional[HydraContext]) -> "Callbacks":
if hydra_context is None:
warnings.warn(
message="\n"
"\trun_job's signature has changed in Hydra 1.1.\n"
"\tSupport for the old style will be removed in Hydra 1.2.\n"
)
from hydra._internal.callbacks import Callbacks

callbacks = Callbacks(OmegaConf.create())
else:
callbacks = hydra_context.callbacks

return callbacks


def run_job(
*,
hydra_context: HydraContext,
task_function: TaskFunction,
config: DictConfig,
job_dir_key: str,
job_subdir_key: Optional[str],
configure_logging: bool = True,
hydra_context: Optional[HydraContext] = None,
) -> "JobReturn":

callbacks = hydra_context.callbacks
callbacks = _get_callbacks_for_run_job(hydra_context)

old_cwd = os.getcwd()
orig_hydra_cfg = HydraConfig.instance().cfg
Expand Down
10 changes: 8 additions & 2 deletions hydra/plugins/sweeper.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,15 @@ def validate_batch_is_legal(self, batch: Sequence[Sequence[str]]) -> None:
This repeat work the launcher will do, but as the launcher may be performing this in a different
process/machine it's important to do it here as well to detect failures early.
"""
assert self.hydra_context is not None
config_loader = (
self.hydra_context.config_loader
if hasattr(self, "hydra_context") and self.hydra_context is not None
else self.config_loader # type: ignore
)
assert config_loader is not None

assert self.config is not None
for overrides in batch:
self.hydra_context.config_loader.load_sweep_config(
config_loader.load_sweep_config(
master_config=self.config, sweep_overrides=list(overrides)
)
79 changes: 79 additions & 0 deletions tests/test_hydra_context_warnings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import re
from typing import Any, Optional
from unittest.mock import Mock

from omegaconf import DictConfig
from pytest import mark, raises, warns

from hydra import TaskFunction
from hydra.core.config_loader import ConfigLoader
from hydra.core.plugins import Plugins
from hydra.core.utils import _get_callbacks_for_run_job
from hydra.plugins.sweeper import Sweeper
from hydra.test_utils.test_utils import chdir_hydra_root
from hydra.types import HydraContext

chdir_hydra_root()

plugins = Plugins.instance()


@mark.parametrize(
"config_loader, hydra_context, setup_method, expected",
[
(None, None, lambda foo, bar: None, AssertionError),
(
Mock(ConfigLoader),
None,
lambda config_loader, config, task_function: None,
UserWarning(),
),
(
Mock(spec=ConfigLoader),
Mock(spec=HydraContext),
lambda hydra_context, config, task_function: None,
None,
),
],
)
def test_setup_plugin(
config_loader: Optional[ConfigLoader],
hydra_context: Optional[HydraContext],
setup_method: Any,
expected: Any,
) -> None:
plugin = Mock(spec=Sweeper)
plugin.setup = setup_method
config = Mock(spec=DictConfig)
task_function = Mock(spec=TaskFunction)
msg = (
"\n"
"\tPlugin's setup() signature has changed in Hydra 1.1.\n"
"\tSupport for the old style will be removed in Hydra 1.2.\n"
)
if expected is None:
plugins._setup_plugin(
plugin, config, task_function, config_loader, hydra_context
)
elif isinstance(expected, UserWarning):
with warns(expected_warning=UserWarning, match=re.escape(msg)):
plugins._setup_plugin(
plugin, task_function, config, config_loader, hydra_context
)
else:
with raises(expected):
plugins._setup_plugin(
plugin, config, task_function, config_loader, hydra_context
)


def test_run_job() -> None:
hydra_context = None
msg = (
"\n"
"\trun_job's signature has changed in Hydra 1.1.\n"
"\tSupport for the old style will be removed in Hydra 1.2.\n"
)
with warns(expected_warning=UserWarning, match=msg):
_get_callbacks_for_run_job(hydra_context)

0 comments on commit 4c73fab

Please sign in to comment.