Skip to content

Commit

Permalink
Add a new hydra util (#256)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmke8 authored Sep 26, 2023
2 parents 16bb2ab + a9a556b commit cf694b9
Show file tree
Hide file tree
Showing 5 changed files with 336 additions and 8 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
"pl.LightningModule": "pytorch_lightning.LightningModule",
}
autodoc_mock_imports = [
"attr",
"attrs",
"hydra",
"loguru",
"numpy",
Expand Down
24 changes: 21 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,13 @@ loguru = { version = "^0.6.0", optional = true }
# hydra dependencies
hydra-core = { version = "^1.3.0", optional = true }
neoconfigen = { version = ">=2.3.3", optional = true }
attrs = { version = "^23.1.0", optional = true }

[tool.poetry.extras]
wandb = ["pandas", "wandb"]
logging = ["loguru"]
hydra = ["hydra-core", "neoconfigen"]
all = ["hydra-core", "loguru", "neoconfigen", "pandas", "wandb"]
hydra = ["attrs", "hydra-core", "neoconfigen"]
all = ["attrs", "hydra-core", "loguru", "neoconfigen", "pandas", "wandb"]

[tool.poetry.group.format.dependencies]
black = "^23.1"
Expand Down
106 changes: 104 additions & 2 deletions ranzen/hydra/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@
from __future__ import annotations
from collections.abc import MutableMapping
from contextlib import contextmanager
from dataclasses import asdict
import dataclasses
from dataclasses import MISSING, Field, asdict, is_dataclass
from enum import Enum
import shlex
from typing import Any, Iterator, Sequence
from typing import Any, Dict, Final, Iterator, Sequence, Union, cast
from typing_extensions import deprecated

import attrs
from attrs import NOTHING, Attribute
from hydra.core.config_store import ConfigStore
from hydra.core.hydra_config import HydraConfig
from hydra.utils import instantiate
Expand All @@ -20,8 +24,13 @@
"as_pretty_dict",
"reconstruct_cmd",
"recursively_instantiate",
"prepare_for_logging",
"register_hydra_config",
]

# Shared sentence fragments used to build the ValueError messages raised by
# `register_hydra_config` (e.g. f"{IF} type Any, {NEED} variants: ...").
NEED: Final = "there should be"
IF: Final = "if an entry has"


def _clean_up_dict(obj: Any) -> Any:
"""Convert enums to strings and filter out _target_."""
Expand All @@ -47,6 +56,7 @@ def reconstruct_cmd() -> str:
return shlex.join([program] + OmegaConf.to_container(args)) # type: ignore[operator]


@deprecated("Use _recursive_=True instead.")
def recursively_instantiate(
hydra_config: DictConfig, *, keys_to_exclude: Sequence[str] = ()
) -> dict[str, Any]:
Expand Down Expand Up @@ -101,3 +111,95 @@ def __init__(self, cs: ConfigStore, *, group_name: str, package: str):
def add_option(self, config_class: type, *, name: str) -> None:
    """Store ``config_class`` as the option called ``name`` in this group.

    The group name and package used for the store call are the ones held on
    this instance (``self._group_name`` and ``self._package``).

    :param config_class: The schema class to register.
    :param name: The name under which the option is stored.
    """
    self._cs.store(
        node=config_class,
        name=name,
        group=self._group_name,
        package=self._package,
    )


def prepare_for_logging(hydra_config: DictConfig, *, enum_to_str: bool = True) -> dict[str, Any]:
    """Convert a hydra config into a plain dictionary that is nicer to log.

    The config is turned into a primitive container with references resolved
    (``resolve=True``) and ``throw_on_missing=True``. Enums are converted to
    strings when ``enum_to_str`` is true. Each top-level entry that is itself a
    ``DictConfig`` has the name of its type appended to its key, as
    ``"<key>/<TypeName>"``; all other top-level keys are kept unchanged. Nested
    entries are not relabelled.

    :param hydra_config: The config to convert.
    :param enum_to_str: Whether enum values should be converted to strings.
    :returns: A dictionary with the converted values and type-annotated keys.
    """
    container = OmegaConf.to_container(
        hydra_config, resolve=True, throw_on_missing=True, enum_to_str=enum_to_str
    )
    assert isinstance(container, dict)
    plain = cast(Dict[str, Any], container)

    labelled: dict[str, Any] = {}
    for key, value in plain.items():
        # Look at the original node (not the converted value) to recover its type.
        node = hydra_config[key]
        if isinstance(node, DictConfig):
            type_name = OmegaConf.get_type(node).__name__  # type: ignore
            labelled[f"{key}/{type_name}"] = value
        else:
            labelled[key] = value
    return labelled


def register_hydra_config(main_cls: type, groups: dict[str, dict[str, type]]) -> None:
    """Validate the given config definition and register it with the ``ConfigStore``.

    This function does two things: 1) it runs consistency checks on the fields of
    ``main_cls`` against ``groups`` and raises a descriptive error if something is
    wrong, and 2) it makes the necessary ``ConfigStore`` calls: ``main_cls`` is stored
    under the name ``"config_schema"`` and every option is stored under its group.

    The checks are, for each field of ``main_cls``:

    - a field annotated with ``Any`` must have an entry in ``groups`` and must not
      have a default value;
    - any other field must not have an entry in ``groups`` and must have a default
      value (or a default factory).

    :param main_cls: The main config class; can be a dataclass or an attrs class.
    :param groups: A dictionary that defines all the variants. The keys of the top
        level of the dictionary should correspond to the group names, and the keys in
        the nested dictionaries should correspond to the names of the options.
    :raises ValueError: If the config is malformed in some way.
    :raises RuntimeError: If storing one of the options in the ``ConfigStore`` fails.

    :example:

    .. code-block:: python

        @dataclass
        class DataModule:
            root: Path = Path()

        @dataclass
        class LinearModel:
            dim: int = 256

        @dataclass
        class CNNModel:
            kernel: int = 3

        @dataclass
        class Config:
            # Fields without a default have to come before fields with one.
            model: Any
            dm: DataModule = dataclasses.field(default_factory=DataModule)

        groups = {"model": {"linear": LinearModel, "cnn": CNNModel}}
        register_hydra_config(Config, groups)
    """
    entries_: Union[tuple[Attribute, ...], tuple[Field, ...]]
    absent: Any  # the "no default" sentinel of whichever library defines `main_cls`
    if is_dataclass(main_cls):
        entries_ = dataclasses.fields(main_cls)
        absent = MISSING
    elif attrs.has(main_cls):
        entries_ = attrs.fields(main_cls)
        absent = NOTHING
    else:
        raise ValueError("The given class is neither a dataclass nor an attrs class.")

    for entry in entries_:
        annotation = entry.type
        # The annotation is either the `Any` object itself or the string "Any".
        is_any = annotation == Any or (isinstance(annotation, str) and annotation == "Any")
        # Only dataclass fields carry a separate `default_factory` attribute.
        has_default = entry.default is not absent or (
            isinstance(entry, Field) and entry.default_factory is not absent
        )
        has_variants = entry.name in groups

        if is_any:
            if not has_variants:
                raise ValueError(f"{IF} type Any, {NEED} variants: `{entry.name}`")
            if has_default:
                raise ValueError(f"{IF} type Any, {NEED} no default value: `{entry.name}`")
            continue

        if has_variants:
            raise ValueError(f"{IF} a real type, {NEED} no variants: `{entry.name}`")
        if not has_default:
            raise ValueError(f"{IF} a real type, {NEED} a default value: `{entry.name}`")

    store = ConfigStore.instance()
    store.store(name="config_schema", node=main_cls)
    for group, options in groups.items():
        for name, node in options.items():
            try:
                store.store(group=group, name=name, node=node)
            except Exception as exc:
                # Attach the offending arguments so the failing option can be identified.
                raise RuntimeError(f"{main_cls=}, {node=}, {name=}, {group=}") from exc
Loading

0 comments on commit cf694b9

Please sign in to comment.