Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compatibility-breaking change in hydra 1.1.0dev7 #52

Merged
merged 8 commits into from
May 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ commands = pytest \


[testenv:min-hydra] # test against earliest supported version of hydra
deps = hydra-core==1.1.0dev5
deps = hydra-core==1.1.0dev7
{[testenv]deps}
basepython = python3.7

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
]
KEYWORDS = "machine learning research configuration scalable reproducible"
INSTALL_REQUIRES = [
"hydra-core >= 1.1.0dev5",
"hydra-core >= 1.1.0dev7",
"typing-extensions >= 3.7.4.1",
]
TESTS_REQUIRE = [
Expand Down
78 changes: 51 additions & 27 deletions src/hydra_zen/experimental/_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from pathlib import Path
from typing import Any, Callable, List, Mapping, Optional, Union

from hydra import compose, initialize
from hydra._internal.callbacks import Callbacks
from hydra._internal.hydra import Hydra
from hydra._internal.utils import create_config_search_path
from hydra.core.config_store import ConfigStore
from hydra.core.global_hydra import GlobalHydra
from hydra.core.utils import JobReturn, run_job
from hydra.experimental import compose, initialize
from hydra.plugins.sweeper import Sweeper
from hydra.types import HydraContext
from omegaconf import DictConfig

from .._hydra_overloads import instantiate
Expand All @@ -19,15 +21,15 @@
def _store_config(
cfg: Union[DataClass, DictConfig, Mapping], config_name: str = "hydra_launch"
) -> str:
"""Generates a Structured Config and registers it in the ConfigStore.
"""Stores configuration object in Hydra's ConfigStore.

Parameters
----------
cfg: Union[DataClass, DictConfig, Mapping]
A configuration as a dataclass, configuration object, or a dictionary.

config_name: str (default: hydra_launch)
A default configuration name if available, otherwise a new object is
The configuration name used to store the configuration.

Returns
-------
Expand Down Expand Up @@ -64,19 +66,20 @@ def _load_config(
Returns
-------
config: DictConfig
The configuration object including Hydra configuration.

Notes
-----
This function uses Hydra's Compose API [1]_

References
----------
.. [1] https://hydra.cc/docs/experimental/compose_api
.. [2] https://hydra.cc/docs/configure_hydra/intro
.. [3] https://hydra.cc/docs/advanced/override_grammar/basic
.. [1] https://hydra.cc/docs/next/advanced/compose_api
.. [2] https://hydra.cc/docs/next/configure_hydra/intro
.. [3] https://hydra.cc/docs/next/advanced/override_grammar/basic
"""

with initialize():
with initialize(config_path=None):
task_cfg = compose(
config_name,
overrides=[] if overrides is None else overrides,
Expand All @@ -90,7 +93,9 @@ def hydra_run(
config: Union[DataClass, DictConfig, Mapping],
task_function: Callable[[DictConfig], Any],
overrides: Optional[List[str]] = None,
config_dir: Optional[Union[str, Path]] = None,
config_name: str = "hydra_run",
job_name: str = "hydra_run",
) -> JobReturn:
"""Launch a Hydra job defined by `task_function` using the configuration
provided in `config`.
Expand Down Expand Up @@ -121,6 +126,9 @@ def hydra_run(
config_dir: Optional[Union[str, Path]] (default: None)
Add configuration directories if needed.

config_name: str (default: "hydra_run")
Name of the stored configuration in Hydra's ConfigStore API.

job_name: str (default: "hydra_run")

Returns
Expand All @@ -136,8 +144,8 @@ def hydra_run(

References
----------
.. [1] https://hydra.cc/docs/advanced/override_grammar/basic
.. [2] https://hydra.cc/docs/configure_hydra/intro
.. [1] https://hydra.cc/docs/next/advanced/override_grammar/basic
.. [2] https://hydra.cc/docs/next/configure_hydra/intro

Examples
--------
Expand Down Expand Up @@ -186,14 +194,28 @@ def hydra_run(
config_name = _store_config(config, config_name)
task_cfg = _load_config(config_name=config_name, overrides=overrides)

if config_dir is not None:
config_dir = str(Path(config_dir).absolute())
search_path = create_config_search_path(config_dir)

hydra = Hydra.create_main_hydra2(task_name=job_name, config_search_path=search_path)

try:
callbacks = Callbacks(task_cfg)
callbacks.on_run_start(config=task_cfg, config_name=config_name)

job = run_job(
hydra_context=HydraContext(
config_loader=hydra.config_loader, callbacks=callbacks
),
config=task_cfg,
task_function=task_function,
job_dir_key="hydra.run.dir",
job_subdir_key=None,
configure_logging=False,
)

callbacks.on_run_end(config=task_cfg, config_name=config_name)
finally:
GlobalHydra.instance().clear()
return job
Expand All @@ -206,7 +228,7 @@ def hydra_multirun(
config_dir: Optional[Union[str, Path]] = None,
config_name: str = "hydra_multirun",
job_name: str = "hydra_multirun",
) -> List[JobReturn]:
) -> List[Any]:
"""Launch a Hydra multi-run ([1]_) job defined by `task_function` using the configuration
provided in `config`.

Expand Down Expand Up @@ -244,24 +266,21 @@ def hydra_multirun(
config_dir: Optional[Union[str, Path]] (default: None)
Add configuration directories if needed.

config_name: str (default: "hydra_run")
Name of the stored configuration in Hydra's ConfigStore API.

job_name: str (default: "hydra_multirun")

Returns
-------
result: List[List[JobReturn]]
The object storing the results of each Hydra experiment.
- overrides: From `overrides` and `multirun_overrides`
- return_value: The return value of the task function
- cfg: The configuration object sent to the task function
- hydra_cfg: The hydra configuration object
- working_dir: The experiment working directory
- task_name: The task name of the Hydra job
result: List[List[Any]]
The return values of all launched jobs (depends on the Sweeper implementation).

References
----------
.. [1] https://hydra.cc/docs/tutorials/basic/running_your_app/multi-run
.. [2] https://hydra.cc/docs/advanced/override_grammar/basic
.. [3] https://hydra.cc/docs/configure_hydra/intro
.. [2] https://hydra.cc/docs/next/advanced/override_grammar/basic
.. [3] https://hydra.cc/docs/next/configure_hydra/intro

Examples
--------
Expand Down Expand Up @@ -315,15 +334,16 @@ def hydra_multirun(
# Separate Hydra overrides from experiment overrides
hydra_overrides = []
_overrides = []
for o in overrides:
if o.startswith("hydra"):
hydra_overrides.append(o)
else:
_overrides.append(o)
if overrides is not None:
for o in overrides:
if o.startswith("hydra"):
hydra_overrides.append(o)
else:
_overrides.append(o)

# Only the hydra overrides are needed to extract the Hydra configuration for
# the launcher and sweepers.
# The sweeper handles the overides for each experiment
# The sweeper handles the overrides for each experiment
config_name = _store_config(config, config_name)
task_cfg = _load_config(config_name=config_name, overrides=hydra_overrides)

Expand All @@ -333,12 +353,15 @@ def hydra_multirun(

hydra = Hydra.create_main_hydra2(task_name=job_name, config_search_path=search_path)
try:
callbacks = Callbacks(task_cfg)
callbacks.on_multirun_start(config=task_cfg, config_name=config_name)

# Instantiate sweeper without using Hydra's Plugin discovery
sweeper = instantiate(task_cfg.hydra.sweeper)
assert isinstance(sweeper, Sweeper)
sweeper.setup(
config=task_cfg,
config_loader=hydra.config_loader,
hydra_context=HydraContext(hydra.config_loader, callbacks=callbacks),
task_function=task_function,
)

Expand All @@ -348,6 +371,7 @@ def hydra_multirun(
# just ensures repeats are removed
_overrides = list(set(_overrides))
job = sweeper.sweep(arguments=_overrides)
callbacks.on_multirun_end(config=task_cfg, config_name=config_name)
finally:
GlobalHydra.instance().clear()
return job
32 changes: 0 additions & 32 deletions src/hydra_zen/structured_configs/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,38 +142,6 @@ def get_obj_path(obj: Any) -> str:
return f"{module}.{name}"


def interpolated(func: Union[str, Callable], *literals: Any) -> str:
"""Produces an hydra-style interpolated string for calling the provided
function on the literals

Parameters
----------
func : Union[str, Callable]
The name of the function to use in the interpolation. The name
will be inferred if a function is provided.

literals : Any
Position-only literals to be fed to the function.

Notes
-----
See https://omegaconf.readthedocs.io/en/latest/usage.html#custom-interpolations
for more details about leveraging custom interpolations in omegaconf/hydra.

Examples
--------
>>> def add(x, y): return x + y
>>> interpolated(add, 1, 2)
'${add:1,2}'
"""
if not isinstance(func, str) and not hasattr(func, "__name__"): # pragma: no cover
raise TypeError(
f"`func` must be a string or have a `__name__` field, got: {func}"
)
name = func if isinstance(func, str) else func.__name__
return f"${{{name}:{','.join(repr(i) for i in literals)}}}"


NoneType = type(None)


Expand Down
115 changes: 115 additions & 0 deletions tests/experimental/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from typing import NamedTuple

import pytest
from hydra.core.config_store import ConfigStore
from hydra.core.utils import JobReturn
from hydra.experimental.callback import Callback
from omegaconf import DictConfig

from hydra_zen import builds, instantiate
from hydra_zen.experimental import hydra_multirun, hydra_run


class Tracker(NamedTuple):
job_start: bool = False
job_end: bool = False
run_start: bool = False
run_end: bool = False
multirun_start: bool = False
multirun_end: bool = False


class CustomCallback(Callback):
JOB_START_CALLED = False
JOB_END_CALLED = False

RUN_START_CALLED = False
RUN_END_CALLED = False

MULTIRUN_START_CALLED = False
MULTIRUN_END_CALLED = False

def __init__(self, callback_name):
self.name = callback_name

def on_job_start(self, config: DictConfig, **kwargs) -> None:
CustomCallback.JOB_START_CALLED = True

def on_job_end(self, config: DictConfig, job_return: JobReturn, **kwargs) -> None:
CustomCallback.JOB_END_CALLED = True

def on_run_start(self, config: DictConfig, **kwargs) -> None:
CustomCallback.RUN_START_CALLED = True

def on_run_end(self, config: DictConfig, **kwargs) -> None:
CustomCallback.RUN_END_CALLED = True

def on_multirun_start(self, config: DictConfig, **kwargs) -> None:
CustomCallback.MULTIRUN_START_CALLED = True

def on_multirun_end(self, config: DictConfig, **kwargs) -> None:
CustomCallback.MULTIRUN_END_CALLED = True


cs = ConfigStore.instance()
cs.store(
group="hydra/callbacks",
name="test_callback",
node=dict(test_callback=builds(CustomCallback, callback_name="test")),
)


def tracker(x=CustomCallback):
# this will get called after the job and run have started
# but before they end
return Tracker(
job_start=x.JOB_START_CALLED,
job_end=x.JOB_END_CALLED,
run_start=x.RUN_START_CALLED,
run_end=x.RUN_END_CALLED,
multirun_start=x.MULTIRUN_START_CALLED,
multirun_end=x.MULTIRUN_END_CALLED,
)


@pytest.mark.usefixtures("cleandir")
@pytest.mark.parametrize("fn", [hydra_run, hydra_multirun])
def test_hydra_run_with_callback(fn):
# Tests that callback methods are called during appropriate
# stages
try:
is_multirun = fn is hydra_multirun

cfg = builds(tracker)

assert not any(tracker()) # ensures all flags are false

job = fn(
cfg, task_function=instantiate, overrides=["hydra/callbacks=test_callback"]
)

if is_multirun:
job = job[0][0]

tracked_mid_run: Tracker = job.return_value
assert tracked_mid_run.job_start is True
assert tracked_mid_run.run_start is not is_multirun
assert tracked_mid_run.multirun_start is is_multirun

assert tracked_mid_run.job_end is False
assert tracked_mid_run.run_end is False
assert tracked_mid_run.multirun_end is False

assert CustomCallback.JOB_END_CALLED is True
assert CustomCallback.RUN_END_CALLED is not is_multirun
assert CustomCallback.MULTIRUN_END_CALLED is is_multirun

finally:
CustomCallback.JOB_START_CALLED = False
CustomCallback.JOB_END_CALLED = False

CustomCallback.RUN_START_CALLED = False
CustomCallback.RUN_END_CALLED = False

CustomCallback.MULTIRUN_START_CALLED = False
CustomCallback.MULTIRUN_END_CALLED = False
Loading