Skip to content

Commit

Permalink
Merge branch 'master' into ray-launcher-v2
Browse files Browse the repository at this point in the history
  • Loading branch information
jieru-hu committed Apr 13, 2020
2 parents 775c185 + 079c325 commit ad3d76c
Show file tree
Hide file tree
Showing 66 changed files with 883 additions and 738 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ __pycache__
/.nox
/report.json
/.coverage
/.mypy_cache
.mypy_cache
pip-wheel-metadata
.ipynb_checkpoints
/.dmypy.json
2 changes: 1 addition & 1 deletion .isort.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ force_grid_wrap=0
use_parentheses=True
line_length=88
ensure_newline_before_comments=True
known_third_party=nevergrad,ax,joblib,omegaconf,ray,pytest,typing_extensions
known_third_party=omegaconf,ray,pytest,typing_extensions
known_first_party=hydra,hydra_plugins
8 changes: 0 additions & 8 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,3 @@ repos:
hooks:
- id: flake8
additional_dependencies: [-e, "git+git://github.com/pycqa/pyflakes.git@1911c20#egg=pyflakes"]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.761
hooks:
- id: mypy
args: [--strict]
additional_dependencies: ["omegaconf==2.0.0rc18"]
exclude: setup.py
2 changes: 1 addition & 1 deletion examples/advanced/hydra_app_example/hydra_app/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from omegaconf import DictConfig

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import hydra


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

@hydra.main()
def experiment(_cfg: DictConfig) -> None:
print(HydraConfig.instance().hydra.job.name)
print(HydraConfig.get().job.name)


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

@hydra.main(config_path="config.yaml")
def experiment(_cfg: DictConfig) -> None:
print(HydraConfig.instance().hydra.job.name)
print(HydraConfig.get().job.name)


if __name__ == "__main__":
Expand Down
5 changes: 2 additions & 3 deletions hydra/_internal/core_plugins/basic_sweeper.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,15 @@

from omegaconf import DictConfig

from hydra.conf import PluginConf
from hydra.core.config_loader import ConfigLoader
from hydra.core.config_store import ConfigStore
from hydra.core.utils import JobReturn
from hydra.plugins.sweeper import Sweeper
from hydra.types import TaskFunction
from hydra.types import ObjectConf, TaskFunction


@dataclass
class BasicSweeperConf(PluginConf):
class BasicSweeperConf(ObjectConf):
cls: str = "hydra._internal.core_plugins.basic_sweeper.BasicSweeper"

@dataclass
Expand Down
122 changes: 120 additions & 2 deletions hydra/_internal/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import argparse
import copy
import inspect
import logging.config
import os
import sys
import warnings
from os.path import dirname, join, normpath, realpath
from typing import Any, List, Optional, Sequence, Tuple
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union

from omegaconf import DictConfig, OmegaConf, _utils, read_write

from hydra._internal.config_search_path_impl import ConfigSearchPathImpl
from hydra.core.config_search_path import ConfigSearchPath
from hydra.core.utils import get_valid_filename, split_config_path
from hydra.types import TaskFunction
from hydra.types import ObjectConf, TaskFunction

log = logging.getLogger(__name__)


def detect_calling_file_or_module(
Expand Down Expand Up @@ -302,3 +309,114 @@ def get_column_widths(matrix: List[List[str]]) -> List[int]:
widths[idx] = max(widths[idx], len(col))

return widths


def _instantiate_class(
clazz: Type[Any], config: Union[ObjectConf, DictConfig], *args: Any, **kwargs: Any
) -> Any:
final_kwargs = _get_kwargs(config, **kwargs)
return clazz(*args, **final_kwargs)


def _call_callable(
fn: Callable[..., Any],
config: Union[ObjectConf, DictConfig],
*args: Any,
**kwargs: Any,
) -> Any:
final_kwargs = _get_kwargs(config, **kwargs)
return fn(*args, **final_kwargs)


def _locate(path: str) -> Union[type, Callable[..., Any]]:
"""
Locate an object by name or dotted path, importing as necessary.
This is similar to the pydoc function `locate`, except that it checks for
the module from the given path from back to front.
"""
import builtins
from importlib import import_module

parts = [part for part in path.split(".") if part]
module = None
for n in reversed(range(len(parts))):
try:
module = import_module(".".join(parts[:n]))
except Exception as e:
if n == 0:
log.error(f"Error loading module {path} : {e}")
raise e
continue
if module:
break
if module:
obj = module
else:
obj = builtins
for part in parts[n:]:
if not hasattr(obj, part):
raise ValueError(
f"Error finding attribute ({part}) in class ({obj.__name__}): {path}"
)
obj = getattr(obj, part)
if isinstance(obj, type):
obj_type: type = obj
return obj_type
elif callable(obj):
obj_callable: Callable[..., Any] = obj
return obj_callable
else:
# dummy case
raise ValueError(f"Invalid type ({type(obj)}) found for {path}")


def _get_kwargs(config: Union[ObjectConf, DictConfig], **kwargs: Any) -> Any:
# copy config to avoid mutating it when merging with kwargs
config_copy = copy.deepcopy(config)

# Manually set parent as deepcopy does not currently handles it (https://github.com/omry/omegaconf/issues/130)
# noinspection PyProtectedMember
config_copy._set_parent(config._get_parent()) # type: ignore
config = config_copy

params = config.params if "params" in config else OmegaConf.create()
assert isinstance(
params, DictConfig
), f"Input config params are expected to be a mapping, found {type(config.params).__name__}"
primitives = {}
rest = {}
for k, v in kwargs.items():
if _utils.is_primitive_type(v) or isinstance(v, (dict, list)):
primitives[k] = v
else:
rest[k] = v
final_kwargs = {}
with read_write(params):
params.merge_with(OmegaConf.create(primitives))

for k, v in params.items():
final_kwargs[k] = v

for k, v in rest.items():
final_kwargs[k] = v
return final_kwargs


def _get_cls_name(config: Union[ObjectConf, DictConfig]) -> str:
if "class" in config:
warnings.warn(
"\n"
"ObjectConf field 'class' is deprecated since Hydra 1.0.0 and will be removed in a future Hydra version.\n"
"Offending config class:\n"
f"\tclass={config['class']}\n"
"Change your config to use 'cls' instead of 'class'.\n",
category=UserWarning,
)
classname = config["class"]
assert isinstance(classname, str)
return classname
else:
if "cls" in config:
return config.cls
else:
raise ValueError("Input config does not have a cls field")
16 changes: 4 additions & 12 deletions hydra/conf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from typing import Any, List, Optional

from omegaconf import MISSING

from hydra.core.config_store import ConfigStore
from hydra.types import ObjectConf

hydra_defaults = [
# Hydra's logging config
Expand All @@ -24,15 +25,6 @@
]


@dataclass
# This extends Dict[str, Any] to allow for the deprecated "class" field.
# Once support for class field removed this can stop extending Dict.
class PluginConf(Dict[str, Any]):
# class name for plugin
cls: str = MISSING
params: Any = field(default_factory=dict)


@dataclass
class HelpConf:
app_name: str = MISSING
Expand Down Expand Up @@ -119,9 +111,9 @@ class HydraConf:
job_logging: Any = MISSING

# Sweeper configuration
sweeper: PluginConf = field(default_factory=PluginConf)
sweeper: ObjectConf = field(default_factory=ObjectConf)
# Launcher configuration
launcher: PluginConf = field(default_factory=PluginConf)
launcher: ObjectConf = field(default_factory=ObjectConf)

# Program Help template
help: HelpConf = HelpConf()
Expand Down
26 changes: 18 additions & 8 deletions hydra/core/hydra_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
from typing import Any
from typing import Any, Optional

from omegaconf import DictConfig, OmegaConf

Expand All @@ -9,15 +8,26 @@


class HydraConfig(metaclass=Singleton):
hydra: HydraConf

def __init__(self) -> None:
ret = OmegaConf.structured(HydraConf)
self.hydra = ret
self.cfg: Optional[HydraConf] = None

def set_config(self, cfg: DictConfig) -> None:
self.hydra = copy.deepcopy(cfg.hydra)
OmegaConf.set_readonly(self.hydra, True) # type: ignore
assert cfg is not None
OmegaConf.set_readonly(cfg.hydra, True)
assert OmegaConf.get_type(cfg, "hydra") == HydraConf
self.cfg = cfg # type: ignore

@staticmethod
def get() -> HydraConf:
instance = HydraConfig.instance()
if instance.cfg is None:
raise ValueError("HydraConfig was not set")
return instance.cfg.hydra # type: ignore

@staticmethod
def initialized() -> bool:
instance = HydraConfig.instance()
return instance.cfg is not None

@staticmethod
def instance(*args: Any, **kwargs: Any) -> "HydraConfig":
Expand Down
11 changes: 5 additions & 6 deletions hydra/core/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from omegaconf import DictConfig

from hydra._internal.sources_registry import SourcesRegistry
from hydra.conf import PluginConf
from hydra.core.config_loader import ConfigLoader
from hydra.core.singleton import Singleton
from hydra.plugins.completion_plugin import CompletionPlugin
Expand All @@ -20,7 +19,7 @@
from hydra.plugins.plugin import Plugin
from hydra.plugins.search_path_plugin import SearchPathPlugin
from hydra.plugins.sweeper import Sweeper
from hydra.types import TaskFunction
from hydra.types import ObjectConf, TaskFunction


@dataclass
Expand Down Expand Up @@ -68,10 +67,10 @@ def _initialize(self) -> None:
assert issubclass(source, ConfigSource)
SourcesRegistry.instance().register(source)

def _instantiate(self, config: PluginConf) -> Plugin:
import hydra.utils as utils
def _instantiate(self, config: ObjectConf) -> Plugin:
import hydra._internal.utils as internal_utils

classname = utils._get_cls_name(config)
classname = internal_utils._get_cls_name(config)
try:
if classname is None:
raise ImportError("class not configured")
Expand All @@ -86,7 +85,7 @@ def _instantiate(self, config: PluginConf) -> Plugin:
if classname not in self.class_name_to_class.keys():
raise RuntimeError(f"Unknown plugin class : '{classname}'")
clazz = self.class_name_to_class[classname]
plugin = utils._instantiate_class(clazz=clazz, config=config)
plugin = internal_utils._instantiate_class(clazz=clazz, config=config)
assert isinstance(plugin, Plugin)

except ImportError as e:
Expand Down
2 changes: 1 addition & 1 deletion hydra/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def run_job(
task_cfg = copy.deepcopy(config)
del task_cfg["hydra"]
ret.cfg = task_cfg
ret.hydra_cfg = OmegaConf.create({"hydra": HydraConfig.instance().hydra})
ret.hydra_cfg = OmegaConf.create({"hydra": HydraConfig.get()})
overrides = OmegaConf.to_container(config.hydra.overrides.task)
assert isinstance(overrides, list)
ret.overrides = overrides
Expand Down
Loading

0 comments on commit ad3d76c

Please sign in to comment.