Skip to content

Commit

Permalink
bugfix for flexyclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
soldni committed Dec 3, 2022
1 parent d832e7f commit f3c75b8
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "springs"
version = "1.9.0"
version = "1.9.1"
description = "A set of utilities to create and manage typed configuration files effectively, built on top of OmegaConf."
authors = [
{name = "Luca Soldaini", email = "[email protected]" }
Expand Down
12 changes: 11 additions & 1 deletion src/springs/commandline.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
to_yaml,
unsafe_merge,
)
from .flexyclasses import is_flexyclass
from .logging import configure_logging
from .nicknames import NicknameRegistry
from .rich_utils import (
Expand Down Expand Up @@ -260,8 +261,17 @@ def load_from_file_or_nickname(
loaded_config = NicknameRegistry.get(
name=config_path_or_nickname, raise_if_missing=True
)
if not isinstance(loaded_config, (DictConfig, ListConfig)):

if is_dataclass(loaded_config):
loaded_config = from_dataclass(loaded_config)
elif is_flexyclass(loaded_config):
loaded_config = loaded_config.to_dict_config() # type: ignore
elif not isinstance(loaded_config, (DictConfig, ListConfig)):
raise ValueError(
f"Nickname '{config_path_or_nickname}' is not a "
"DictConfig or ListConfig."
)

else:
# config file is to load from file
loaded_config = from_file(config_path_or_nickname)
Expand Down
13 changes: 13 additions & 0 deletions src/springs/flexyclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Dict, Generic, Type, TypeVar

from omegaconf import MISSING as OC_MISSING
from omegaconf import DictConfig
from typing_extensions import dataclass_transform

from .utils import get_annotations
Expand Down Expand Up @@ -61,6 +62,13 @@ def __new__(cls, **kwargs):
factory_dict = {**cls.defaults(), **kwargs}
return field(default_factory=lambda: factory_dict)

@classmethod
def to_dict_config(cls, **kwargs: Any) -> DictConfig:
"""Convert the FlexyClass to an OmegaConf DictConfig object"""
from .core import from_dict

return from_dict({**cls.defaults(), **kwargs})

@classmethod
def flexyclass(cls, target_cls: Type[C]) -> Type["FlexyClass"]:
"""Decorator to create a FlexyClass from a class"""
Expand Down Expand Up @@ -89,3 +97,8 @@ def flexyclass(cls, target_cls: Type[C]) -> Type["FlexyClass"]:
def flexyclass(cls: Type[C]) -> Type[FlexyClass[C]]:
"""Alias for FlexyClass.flexyclass"""
return FlexyClass.flexyclass(cls)


def is_flexyclass(obj: Any) -> bool:
"""Check if an object is a FlexyClass"""
return isinstance(obj, FlexyClass)
13 changes: 10 additions & 3 deletions src/springs/nicknames.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import is_dataclass
from inspect import isclass
from pathlib import Path
from typing import (
Any,
Expand All @@ -12,15 +13,17 @@
Type,
TypeVar,
Union,
cast,
overload,
)

from omegaconf import DictConfig, ListConfig

from .core import from_file
from .flexyclasses import FlexyClass
from .logging import configure_logging

RegistryValue = Union[Type[Any], DictConfig, ListConfig]
RegistryValue = Union[Type[Any], Type[FlexyClass], DictConfig, ListConfig]

T = TypeVar("T")
M = TypeVar("M", bound=RegistryValue)
Expand Down Expand Up @@ -92,12 +95,16 @@ def add(cls, name: str) -> Callable[[Type[T]], Type[T]]:
for easy reuse."""

def add_to_registry(cls_: Type[T]) -> Type[T]:
if not is_dataclass(cls_):
if not (
is_dataclass(cls_)
or isclass(cls_)
and issubclass(cls_, FlexyClass)
):
raise ValueError(f"{cls_} must be a dataclass")

if name in cls.__registry__:
raise ValueError(f"{name} is already registered")
return cls._add(name, cls_)
return cast(Type[T], cls._add(name, cls_))

return add_to_registry

Expand Down

0 comments on commit f3c75b8

Please sign in to comment.