diff --git a/setup.cfg b/setup.cfg index 7efd87aed..ac8456b60 100644 --- a/setup.cfg +++ b/setup.cfg @@ -41,7 +41,7 @@ commands = pytest \ [testenv:min-hydra] # test against earliest supported version of hydra -deps = hydra-core==1.1.0dev5 +deps = hydra-core==1.1.0dev7 {[testenv]deps} basepython = python3.7 diff --git a/setup.py b/setup.py index 336cab30b..549d28cf2 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ ] KEYWORDS = "machine learning research configuration scalable reproducible" INSTALL_REQUIRES = [ - "hydra-core >= 1.1.0dev5", + "hydra-core >= 1.1.0dev7", "typing-extensions >= 3.7.4.1", ] TESTS_REQUIRE = [ diff --git a/src/hydra_zen/experimental/_implementations.py b/src/hydra_zen/experimental/_implementations.py index 118cd49f8..11abcbe71 100644 --- a/src/hydra_zen/experimental/_implementations.py +++ b/src/hydra_zen/experimental/_implementations.py @@ -3,13 +3,15 @@ from pathlib import Path from typing import Any, Callable, List, Mapping, Optional, Union +from hydra import compose, initialize +from hydra._internal.callbacks import Callbacks from hydra._internal.hydra import Hydra from hydra._internal.utils import create_config_search_path from hydra.core.config_store import ConfigStore from hydra.core.global_hydra import GlobalHydra from hydra.core.utils import JobReturn, run_job -from hydra.experimental import compose, initialize from hydra.plugins.sweeper import Sweeper +from hydra.types import HydraContext from omegaconf import DictConfig from .._hydra_overloads import instantiate @@ -19,7 +21,7 @@ def _store_config( cfg: Union[DataClass, DictConfig, Mapping], config_name: str = "hydra_launch" ) -> str: - """Generates a Structured Config and registers it in the ConfigStore. + """Stores configuration object in Hydra's ConfigStore. Parameters ---------- @@ -27,7 +29,7 @@ def _store_config( A configuration as a dataclass, configuration object, or a dictionary. config_name: str (default: hydra_launch) - A default configuration name if available, otherwise a new object is + The configuration name used to store the configuration. Returns ------- @@ -64,6 +66,7 @@ def _load_config( Returns ------- config: DictConfig + The configuration object including Hydra configuration. Notes ----- @@ -71,12 +74,12 @@ def _load_config( References ---------- - .. [1] https://hydra.cc/docs/experimental/compose_api - .. [2] https://hydra.cc/docs/configure_hydra/intro - .. [3] https://hydra.cc/docs/advanced/override_grammar/basic + .. [1] https://hydra.cc/docs/next/advanced/compose_api + .. [2] https://hydra.cc/docs/next/configure_hydra/intro + .. [3] https://hydra.cc/docs/next/advanced/override_grammar/basic """ - with initialize(): + with initialize(config_path=None): task_cfg = compose( config_name, overrides=[] if overrides is None else overrides, @@ -90,7 +93,9 @@ def hydra_run( config: Union[DataClass, DictConfig, Mapping], task_function: Callable[[DictConfig], Any], overrides: Optional[List[str]] = None, + config_dir: Optional[Union[str, Path]] = None, config_name: str = "hydra_run", + job_name: str = "hydra_run", ) -> JobReturn: """Launch a Hydra job defined by `task_function` using the configuration provided in `config`. @@ -121,6 +126,9 @@ def hydra_run( config_dir: Optional[Union[str, Path]] (default: None) Add configuration directories if needed. + config_name: str (default: "hydra_run") + Name of the stored configuration in Hydra's ConfigStore API. + job_name: str (default: "hydra_run") Returns @@ -136,8 +144,8 @@ def hydra_run( References ---------- - .. [1] https://hydra.cc/docs/advanced/override_grammar/basic - .. [2] https://hydra.cc/docs/configure_hydra/intro + .. [1] https://hydra.cc/docs/next/advanced/override_grammar/basic + .. [2] https://hydra.cc/docs/next/configure_hydra/intro Examples -------- @@ -186,14 +194,28 @@ def hydra_run( config_name = _store_config(config, config_name) task_cfg = _load_config(config_name=config_name, overrides=overrides) + if config_dir is not None: + config_dir = str(Path(config_dir).absolute()) + search_path = create_config_search_path(config_dir) + + hydra = Hydra.create_main_hydra2(task_name=job_name, config_search_path=search_path) + try: + callbacks = Callbacks(task_cfg) + callbacks.on_run_start(config=task_cfg, config_name=config_name) + job = run_job( + hydra_context=HydraContext( + config_loader=hydra.config_loader, callbacks=callbacks + ), config=task_cfg, task_function=task_function, job_dir_key="hydra.run.dir", job_subdir_key=None, configure_logging=False, ) + + callbacks.on_run_end(config=task_cfg, config_name=config_name) finally: GlobalHydra.instance().clear() return job @@ -206,7 +228,7 @@ def hydra_multirun( config_dir: Optional[Union[str, Path]] = None, config_name: str = "hydra_multirun", job_name: str = "hydra_multirun", -) -> List[JobReturn]: +) -> List[Any]: """Launch a Hydra multi-run ([1]_) job defined by `task_function` using the configuration provided in `config`. @@ -244,24 +266,21 @@ def hydra_multirun( config_dir: Optional[Union[str, Path]] (default: None) Add configuration directories if needed. + config_name: str (default: "hydra_run") + Name of the stored configuration in Hydra's ConfigStore API. + job_name: str (default: "hydra_multirun") Returns ------- - result: List[List[JobReturn]] - The object storing the results of each Hydra experiment. - - overrides: From `overrides` and `multirun_overrides` - - return_value: The return value of the task function - - cfg: The configuration object sent to the task function - - hydra_cfg: The hydra configuration object - - working_dir: The experiment working directory - - task_name: The task name of the Hydra job + result: List[List[Any]] + The return values of all launched jobs (depends on the Sweeper implementation). References ---------- .. [1] https://hydra.cc/docs/tutorials/basic/running_your_app/multi-run - .. [2] https://hydra.cc/docs/advanced/override_grammar/basic - .. [3] https://hydra.cc/docs/configure_hydra/intro + .. [2] https://hydra.cc/docs/next/advanced/override_grammar/basic + .. [3] https://hydra.cc/docs/next/configure_hydra/intro Examples -------- @@ -315,15 +334,16 @@ def hydra_multirun( # Separate Hydra overrides from experiment overrides hydra_overrides = [] _overrides = [] - for o in overrides: - if o.startswith("hydra"): - hydra_overrides.append(o) - else: - _overrides.append(o) + if overrides is not None: + for o in overrides: + if o.startswith("hydra"): + hydra_overrides.append(o) + else: + _overrides.append(o) # Only the hydra overrides are needed to extract the Hydra configuration for # the launcher and sweepers. - # The sweeper handles the overides for each experiment + # The sweeper handles the overrides for each experiment config_name = _store_config(config, config_name) task_cfg = _load_config(config_name=config_name, overrides=hydra_overrides) @@ -333,12 +353,15 @@ def hydra_multirun( hydra = Hydra.create_main_hydra2(task_name=job_name, config_search_path=search_path) try: + callbacks = Callbacks(task_cfg) + callbacks.on_multirun_start(config=task_cfg, config_name=config_name) + # Instantiate sweeper without using Hydra's Plugin discovery sweeper = instantiate(task_cfg.hydra.sweeper) assert isinstance(sweeper, Sweeper) sweeper.setup( config=task_cfg, - config_loader=hydra.config_loader, + hydra_context=HydraContext(hydra.config_loader, callbacks=callbacks), task_function=task_function, ) @@ -348,6 +371,7 @@ def hydra_multirun( # just ensures repeats are removed _overrides = list(set(_overrides)) job = sweeper.sweep(arguments=_overrides) + callbacks.on_multirun_end(config=task_cfg, config_name=config_name) finally: GlobalHydra.instance().clear() return job diff --git a/src/hydra_zen/structured_configs/_utils.py b/src/hydra_zen/structured_configs/_utils.py index e91dc915f..204dfd492 100644 --- a/src/hydra_zen/structured_configs/_utils.py +++ b/src/hydra_zen/structured_configs/_utils.py @@ -142,38 +142,6 @@ def get_obj_path(obj: Any) -> str: return f"{module}.{name}" -def interpolated(func: Union[str, Callable], *literals: Any) -> str: - """Produces an hydra-style interpolated string for calling the provided - function on the literals - - Parameters - ---------- - func : Union[str, Callable] - The name of the function to use in the interpolation. The name - will be inferred if a function is provided. - - literals : Any - Position-only literals to be fed to the function. - - Notes - ----- - See https://omegaconf.readthedocs.io/en/latest/usage.html#custom-interpolations - for more details about leveraging custom interpolations in omegaconf/hydra. - - Examples - -------- - >>> def add(x, y): return x + y - >>> interpolated(add, 1, 2) - '${add:1,2}' - """ - if not isinstance(func, str) and not hasattr(func, "__name__"): # pragma: no cover - raise TypeError( - f"`func` must be a string or have a `__name__` field, got: {func}" - ) - name = func if isinstance(func, str) else func.__name__ - return f"${{{name}:{','.join(repr(i) for i in literals)}}}" - - NoneType = type(None) diff --git a/tests/experimental/test_callbacks.py b/tests/experimental/test_callbacks.py new file mode 100644 index 000000000..e75c68fcb --- /dev/null +++ b/tests/experimental/test_callbacks.py @@ -0,0 +1,115 @@ +from typing import NamedTuple + +import pytest +from hydra.core.config_store import ConfigStore +from hydra.core.utils import JobReturn +from hydra.experimental.callback import Callback +from omegaconf import DictConfig + +from hydra_zen import builds, instantiate +from hydra_zen.experimental import hydra_multirun, hydra_run + + +class Tracker(NamedTuple): + job_start: bool = False + job_end: bool = False + run_start: bool = False + run_end: bool = False + multirun_start: bool = False + multirun_end: bool = False + + +class CustomCallback(Callback): + JOB_START_CALLED = False + JOB_END_CALLED = False + + RUN_START_CALLED = False + RUN_END_CALLED = False + + MULTIRUN_START_CALLED = False + MULTIRUN_END_CALLED = False + + def __init__(self, callback_name): + self.name = callback_name + + def on_job_start(self, config: DictConfig, **kwargs) -> None: + CustomCallback.JOB_START_CALLED = True + + def on_job_end(self, config: DictConfig, job_return: JobReturn, **kwargs) -> None: + CustomCallback.JOB_END_CALLED = True + + def on_run_start(self, config: DictConfig, **kwargs) -> None: + CustomCallback.RUN_START_CALLED = True + + def on_run_end(self, config: DictConfig, **kwargs) -> None: + CustomCallback.RUN_END_CALLED = True + + def on_multirun_start(self, config: DictConfig, **kwargs) -> None: + CustomCallback.MULTIRUN_START_CALLED = True + + def on_multirun_end(self, config: DictConfig, **kwargs) -> None: + CustomCallback.MULTIRUN_END_CALLED = True + + +cs = ConfigStore.instance() +cs.store( + group="hydra/callbacks", + name="test_callback", + node=dict(test_callback=builds(CustomCallback, callback_name="test")), +) + + +def tracker(x=CustomCallback): + # this will get called after the job and run have started + # but before they end + return Tracker( + job_start=x.JOB_START_CALLED, + job_end=x.JOB_END_CALLED, + run_start=x.RUN_START_CALLED, + run_end=x.RUN_END_CALLED, + multirun_start=x.MULTIRUN_START_CALLED, + multirun_end=x.MULTIRUN_END_CALLED, + ) + + +@pytest.mark.usefixtures("cleandir") +@pytest.mark.parametrize("fn", [hydra_run, hydra_multirun]) +def test_hydra_run_with_callback(fn): + # Tests that callback methods are called during appropriate + # stages + try: + is_multirun = fn is hydra_multirun + + cfg = builds(tracker) + + assert not any(tracker()) # ensures all flags are false + + job = fn( + cfg, task_function=instantiate, overrides=["hydra/callbacks=test_callback"] + ) + + if is_multirun: + job = job[0][0] + + tracked_mid_run: Tracker = job.return_value + assert tracked_mid_run.job_start is True + assert tracked_mid_run.run_start is not is_multirun + assert tracked_mid_run.multirun_start is is_multirun + + assert tracked_mid_run.job_end is False + assert tracked_mid_run.run_end is False + assert tracked_mid_run.multirun_end is False + + assert CustomCallback.JOB_END_CALLED is True + assert CustomCallback.RUN_END_CALLED is not is_multirun + assert CustomCallback.MULTIRUN_END_CALLED is is_multirun + + finally: + CustomCallback.JOB_START_CALLED = False + CustomCallback.JOB_END_CALLED = False + + CustomCallback.RUN_START_CALLED = False + CustomCallback.RUN_END_CALLED = False + + CustomCallback.MULTIRUN_START_CALLED = False + CustomCallback.MULTIRUN_END_CALLED = False diff --git a/tests/experimental/test_implementations.py b/tests/experimental/test_implementations.py index ac19a6705..bbc690001 100644 --- a/tests/experimental/test_implementations.py +++ b/tests/experimental/test_implementations.py @@ -39,7 +39,10 @@ def test_store_config(as_dataclass, as_dictconfig): @pytest.mark.parametrize( "as_dictconfig, with_hydra", [(True, True), (True, False), (False, False)] ) -def test_hydra_run_job(overrides, as_dataclass, as_dictconfig, with_hydra): +@pytest.mark.parametrize("use_default_dir", [True, False]) +def test_hydra_run_job( + overrides, as_dataclass, as_dictconfig, with_hydra, use_default_dir +): if not as_dataclass: cfg = dict(a=1, b=1) else: @@ -55,7 +58,10 @@ def test_hydra_run_job(overrides, as_dataclass, as_dictconfig, with_hydra): cfg = _load_config(cn, overrides=overrides) overrides = [] - job = hydra_run(cfg, task_function=instantiate, overrides=overrides) + additl_kwargs = {} if use_default_dir else dict(config_dir=Path.cwd()) + job = hydra_run( + cfg, task_function=instantiate, overrides=overrides, **additl_kwargs + ) assert job.return_value == {"a": 1, "b": 1} if override_exists == 1: @@ -67,19 +73,27 @@ def test_hydra_run_job(overrides, as_dataclass, as_dictconfig, with_hydra): "overrides", [None, [], ["hydra.sweep.dir=test_hydra_overrided"]], ) +@pytest.mark.parametrize( + "multirun_overrides", + [None, ["a=1,2"]], +) @pytest.mark.parametrize("as_dataclass", [True, False]) @pytest.mark.parametrize( "as_dictconfig, with_hydra", [(True, True), (True, False), (False, False)] ) @pytest.mark.parametrize("use_default_dir", [True, False]) def test_hydra_multirun( - overrides, as_dataclass, as_dictconfig, with_hydra, use_default_dir + overrides, + multirun_overrides, + as_dataclass, + as_dictconfig, + with_hydra, + use_default_dir, ): if not as_dataclass: cfg = dict(a=1, b=1) else: cfg = builds(dict, a=1, b=1) - multirun_overrides = ["a=1,2"] override_exists = overrides and len(overrides) > 1 if as_dictconfig: @@ -91,9 +105,15 @@ def test_hydra_multirun( overrides = [] additl_kwargs = {} if use_default_dir else dict(config_dir=Path.cwd()) - _overrides = ( - multirun_overrides if overrides is None else (overrides + multirun_overrides) - ) + + _overrides = overrides + if multirun_overrides is not None: + _overrides = ( + multirun_overrides + if overrides is None + else (overrides + multirun_overrides) + ) + job = hydra_multirun( cfg, task_function=instantiate, overrides=_overrides, **additl_kwargs ) diff --git a/tests/test_utils.py b/tests/test_utils.py index ea5c16a3a..29c8de0bb 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -20,11 +20,9 @@ from typing_extensions import Final, Literal from hydra_zen import builds, instantiate, mutable_value -from hydra_zen.structured_configs._utils import interpolated, safe_name, sanitized_type +from hydra_zen.structured_configs._utils import safe_name, sanitized_type from hydra_zen.typing import Builds -from . import valid_hydra_literals - T = TypeVar("T") current_module: str = sys.modules[__name__].__name__ @@ -38,31 +36,6 @@ def pass_through_kwargs(**kwargs): return kwargs -omegaconf.OmegaConf.register_new_resolver("_test_pass_through", pass_through) - - -@given(st.lists(valid_hydra_literals)) -def test_interpolate_roundtrip(literals): - interpolated_string = interpolated("_test_pass_through", *literals) - - note(interpolated_string) - - interpolated_literals = OmegaConf.create({"x": interpolated_string}).x - - assert len(literals) == len(interpolated_literals) - - for lit, interp in zip(literals, interpolated_literals): - assert lit == interp - - -OmegaConf.register_new_resolver("len", len) - - -def test_interpolation_with_string_literal(): - cc = instantiate(builds(dict, total=interpolated(len, "9"))) - assert cc["total"] == 1 - - class C: def __repr__(self): return "C as a repr"