Skip to content

Commit

Permalink
Raise warning on oudated methods call
Browse files Browse the repository at this point in the history
  • Loading branch information
jieru-hu committed May 3, 2021
1 parent 059c0b4 commit 0d273e7
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 17 deletions.
7 changes: 4 additions & 3 deletions hydra/_internal/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import warnings
from typing import TYPE_CHECKING, Any

from omegaconf import DictConfig
from omegaconf import DictConfig, OmegaConf

if TYPE_CHECKING:
from hydra.core.utils import JobReturn
Expand All @@ -13,8 +13,9 @@ def __init__(self, config: DictConfig) -> None:
self.callbacks = []
from hydra.utils import instantiate

for params in config.hydra.callbacks.values():
self.callbacks.append(instantiate(params))
if OmegaConf.select(config, "hydra.callbacks"):
for params in config.hydra.callbacks.values():
self.callbacks.append(instantiate(params))

def _notify(self, function_name: str, reverse: bool = False, **kwargs: Any) -> None:
callbacks = reversed(self.callbacks) if reverse else self.callbacks
Expand Down
56 changes: 47 additions & 9 deletions hydra/core/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
from collections import defaultdict
from dataclasses import dataclass, field
from timeit import default_timer as timer
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, Tuple, Type, Union

from omegaconf import DictConfig

from hydra._internal.sources_registry import SourcesRegistry
from hydra.core.config_loader import ConfigLoader
from hydra.core.singleton import Singleton
from hydra.plugins.completion_plugin import CompletionPlugin
from hydra.plugins.config_source import ConfigSource
Expand Down Expand Up @@ -105,36 +106,73 @@ def is_in_toplevel_plugins_module(clazz: str) -> bool:

def instantiate_sweeper(
self,
*,
hydra_context: HydraContext,
task_function: TaskFunction,
config: DictConfig,
config_loader: Optional[ConfigLoader] = None,
hydra_context: Optional[HydraContext] = None,
) -> Sweeper:
Plugins.check_usage(self)
if config.hydra.sweeper is None:
raise RuntimeError("Hydra sweeper is not configured")
sweeper = self._instantiate(config.hydra.sweeper)
assert isinstance(sweeper, Sweeper)
sweeper.setup(
config=config, hydra_context=hydra_context, task_function=task_function
sweeper = self._setup_plugin(
plugin=sweeper,
config=config,
task_function=task_function,
config_loader=config_loader,
hydra_context=hydra_context,
)
assert isinstance(sweeper, Sweeper)
return sweeper

@staticmethod
def _setup_plugin(
plugin: Union[Sweeper, Launcher],
config: DictConfig,
task_function: TaskFunction,
config_loader: Optional[ConfigLoader] = None,
hydra_context: Optional[HydraContext] = None,
) -> Union[Sweeper, Launcher]:
assert (
config_loader is not None or hydra_context is not None
), "config_loader and hydra_context cannot both be None"
if hydra_context is not None:
plugin.setup(
config=config, hydra_context=hydra_context, task_function=task_function
)
else:
warnings.warn(
message="\n"
"\tPlugin's setup method has changed in Hydra 1.1.\n"
"\tSupport for the old style will be removed in Hydra 1.2.\n"
)
plugin.setup(
config=config, config_loader=config_loader, task_function=task_function
) # type: ignore
return plugin

def instantiate_launcher(
self,
*,
hydra_context: HydraContext,
task_function: TaskFunction,
config: DictConfig,
config_loader: Optional[ConfigLoader] = None,
hydra_context: Optional[HydraContext] = None,
) -> Launcher:
Plugins.check_usage(self)
if config.hydra.launcher is None:
raise RuntimeError("Hydra launcher is not configured")

launcher = self._instantiate(config.hydra.launcher)
assert isinstance(launcher, Launcher)
launcher.setup(
config=config, hydra_context=hydra_context, task_function=task_function
launcher = self._setup_plugin(
plugin=launcher,
config=config,
task_function=task_function,
config_loader=config_loader,
hydra_context=hydra_context,
)
assert isinstance(launcher, Launcher)
return launcher

@staticmethod
Expand Down
29 changes: 24 additions & 5 deletions hydra/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,25 @@
import os
import re
import sys
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from os.path import splitext
from pathlib import Path
from textwrap import dedent
from typing import Any, Dict, Optional, Sequence, Union, cast
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union, cast

from omegaconf import DictConfig, OmegaConf, open_dict, read_write

from hydra.core.hydra_config import HydraConfig
from hydra.core.singleton import Singleton
from hydra.errors import HydraJobException

if TYPE_CHECKING:
from hydra._internal.callbacks import Callbacks

from hydra.types import HydraContext, TaskFunction

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -82,17 +87,31 @@ def filter_overrides(overrides: Sequence[str]) -> Sequence[str]:
return [x for x in overrides if not x.startswith("hydra.")]


def _get_callbacks_for_run_job(hydra_context: Optional[HydraContext]) -> "Callbacks":
if hydra_context is None:
warnings.warn(
message="\n"
"\trun_job's signature has changed in Hydra 1.1.\n"
"\tSupport for the old style will be removed in Hydra 1.2.\n"
)
from hydra._internal.callbacks import Callbacks

callbacks = Callbacks(OmegaConf.create())
else:
callbacks = hydra_context.callbacks

return callbacks


def run_job(
*,
hydra_context: HydraContext,
task_function: TaskFunction,
config: DictConfig,
job_dir_key: str,
job_subdir_key: Optional[str],
configure_logging: bool = True,
hydra_context: Optional[HydraContext] = None,
) -> "JobReturn":

callbacks = hydra_context.callbacks
callbacks = _get_callbacks_for_run_job(hydra_context)

old_cwd = os.getcwd()
orig_hydra_cfg = HydraConfig.instance().cfg
Expand Down
74 changes: 74 additions & 0 deletions tests/test_hydra_context_warnings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Any, Optional
from unittest.mock import MagicMock, Mock

from omegaconf import DictConfig
from pytest import mark, raises, warns

from hydra import TaskFunction
from hydra.core.config_loader import ConfigLoader
from hydra.core.plugins import Plugins
from hydra.core.utils import _get_callbacks_for_run_job
from hydra.plugins.sweeper import Sweeper
from hydra.test_utils.test_utils import chdir_hydra_root
from hydra.types import HydraContext

chdir_hydra_root()

plugins = Plugins.instance()


@mark.parametrize(
"config_loader, hydra_context, expected",
[
(None, None, AssertionError),
(Mock(ConfigLoader), None, UserWarning()),
(None, Mock(HydraContext), None),
],
)
def test_setup_plugin(
config_loader: Optional[ConfigLoader],
hydra_context: Optional[HydraContext],
expected: Any,
) -> None:
plugin = Mock(spec=Sweeper)
plugin.setup = MagicMock()
config = Mock(spec=DictConfig)
task_function = Mock(spec=TaskFunction)
msg = (
"\n"
"\tPlugin's setup method has changed in Hydra 1.1.\n"
"\tSupport for the old style will be removed in Hydra 1.2.\n"
)
if expected is None:
plugins._setup_plugin(
plugin, config, task_function, config_loader, hydra_context
)
plugin.setup.assert_called_with(
config=config, hydra_context=hydra_context, task_function=task_function
)
elif isinstance(expected, UserWarning):
with warns(expected_warning=UserWarning, match=msg):
plugins._setup_plugin(
plugin, config, task_function, config_loader, hydra_context
)
plugin.setup.assert_called_with(
config=config, config_loader=config_loader, task_function=task_function
)
else:
with raises(expected):
plugins._setup_plugin(
plugin, config, task_function, config_loader, hydra_context
)
plugin.setup.assert_not_called()


def test_run_job() -> None:
hydra_context = None
msg = (
"\n"
"\trun_job's signature has changed in Hydra 1.1.\n"
"\tSupport for the old style will be removed in Hydra 1.2.\n"
)
with warns(expected_warning=UserWarning, match=msg):
_get_callbacks_for_run_job(hydra_context)

0 comments on commit 0d273e7

Please sign in to comment.