From f3c75b8aae841843b322c51647313a09ffbdd9a7 Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Fri, 2 Dec 2022 21:41:21 -0800 Subject: [PATCH] bugfix for flexyclasses --- pyproject.toml | 2 +- src/springs/commandline.py | 12 +++++++++++- src/springs/flexyclasses.py | 13 +++++++++++++ src/springs/nicknames.py | 13 ++++++++++--- 4 files changed, 35 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 50459be..0dd34ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "springs" -version = "1.9.0" +version = "1.9.1" description = "A set of utilities to create and manage typed configuration files effectively, built on top of OmegaConf." authors = [ {name = "Luca Soldaini", email = "luca@soldaini.net" } diff --git a/src/springs/commandline.py b/src/springs/commandline.py index 697b1a0..b0a10e5 100644 --- a/src/springs/commandline.py +++ b/src/springs/commandline.py @@ -29,6 +29,7 @@ to_yaml, unsafe_merge, ) +from .flexyclasses import is_flexyclass from .logging import configure_logging from .nicknames import NicknameRegistry from .rich_utils import ( @@ -260,8 +261,17 @@ def load_from_file_or_nickname( loaded_config = NicknameRegistry.get( name=config_path_or_nickname, raise_if_missing=True ) - if not isinstance(loaded_config, (DictConfig, ListConfig)): + + if is_dataclass(loaded_config): loaded_config = from_dataclass(loaded_config) + elif is_flexyclass(loaded_config): + loaded_config = loaded_config.to_dict_config() # type: ignore + elif not isinstance(loaded_config, (DictConfig, ListConfig)): + raise ValueError( + f"Nickname '{config_path_or_nickname}' is not a " + "DictConfig or ListConfig." + ) + else: # config file is to load from file loaded_config = from_file(config_path_or_nickname) diff --git a/src/springs/flexyclasses.py b/src/springs/flexyclasses.py index 0d349fc..2969ef4 100644 --- a/src/springs/flexyclasses.py +++ b/src/springs/flexyclasses.py @@ -4,6 +4,7 @@ from typing import Any, Dict, Generic, Type, TypeVar from omegaconf import MISSING as OC_MISSING +from omegaconf import DictConfig from typing_extensions import dataclass_transform from .utils import get_annotations @@ -61,6 +62,13 @@ def __new__(cls, **kwargs): factory_dict = {**cls.defaults(), **kwargs} return field(default_factory=lambda: factory_dict) + @classmethod + def to_dict_config(cls, **kwargs: Any) -> DictConfig: + """Convert the FlexyClass to an OmegaConf DictConfig object""" + from .core import from_dict + + return from_dict({**cls.defaults(), **kwargs}) + @classmethod def flexyclass(cls, target_cls: Type[C]) -> Type["FlexyClass"]: """Decorator to create a FlexyClass from a class""" @@ -89,3 +97,8 @@ def flexyclass(cls, target_cls: Type[C]) -> Type["FlexyClass"]: def flexyclass(cls: Type[C]) -> Type[FlexyClass[C]]: """Alias for FlexyClass.flexyclass""" return FlexyClass.flexyclass(cls) + + +def is_flexyclass(obj: Any) -> bool: + """Check if an object is a FlexyClass""" + return isinstance(obj, FlexyClass) diff --git a/src/springs/nicknames.py b/src/springs/nicknames.py index 870bcd1..3dca2e9 100644 --- a/src/springs/nicknames.py +++ b/src/springs/nicknames.py @@ -1,4 +1,5 @@ from dataclasses import is_dataclass +from inspect import isclass from pathlib import Path from typing import ( Any, @@ -12,15 +13,17 @@ Type, TypeVar, Union, + cast, overload, ) from omegaconf import DictConfig, ListConfig from .core import from_file +from .flexyclasses import FlexyClass from .logging import configure_logging -RegistryValue = Union[Type[Any], DictConfig, ListConfig] +RegistryValue = Union[Type[Any], Type[FlexyClass], DictConfig, ListConfig] T = TypeVar("T") M = TypeVar("M", bound=RegistryValue) @@ -92,12 +95,16 @@ def add(cls, name: str) -> Callable[[Type[T]], Type[T]]: for easy reuse.""" def add_to_registry(cls_: Type[T]) -> Type[T]: - if not is_dataclass(cls_): + if not ( + is_dataclass(cls_) + or isclass(cls_) + and issubclass(cls_, FlexyClass) + ): raise ValueError(f"{cls_} must be a dataclass") if name in cls.__registry__: raise ValueError(f"{name} is already registered") - return cls._add(name, cls_) + return cast(Type[T], cls._add(name, cls_)) return add_to_registry