Skip to content

Commit

Permalink
Merge pull request #22 from bnewm0609/main
Browse files Browse the repository at this point in the history
Upgrade to Python 3.11
  • Loading branch information
soldni authored Aug 2, 2023
2 parents 882cca4 + 5ed3f54 commit 24e58dd
Show file tree
Hide file tree
Showing 14 changed files with 116 additions and 105 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "springs"
version = "1.12.3"
version = "1.13.0"
description = """\
A set of utilities to create and manage typed configuration files \
effectively, built on top of OmegaConf.\
Expand Down
5 changes: 0 additions & 5 deletions src/springs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@
debug_logger,
fdict,
flist,
fobj,
fval,
get_nickname,
make_flexy,
make_target,
Expand All @@ -49,7 +47,6 @@
__version__ = get_version()

__all__ = [
"add_help",
"all_resolvers",
"cast",
"cli",
Expand All @@ -63,8 +60,6 @@
"field",
"flexyclass",
"flist",
"fobj",
"fval",
"from_dataclass",
"from_dict",
"from_file",
Expand Down
40 changes: 25 additions & 15 deletions src/springs/commandline.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import re
import sys
from argparse import Action
Expand Down Expand Up @@ -29,6 +30,7 @@
to_yaml,
unsafe_merge,
)
from .field_utils import field
from .flexyclasses import is_flexyclass
from .logging import configure_logging
from .nicknames import NicknameRegistry
Expand All @@ -42,6 +44,7 @@

# parameters for the main function
MP = ParamSpec("MP")
NP = ParamSpec("NP")

# type for the configuration
CT = TypeVar("CT")
Expand Down Expand Up @@ -92,10 +95,14 @@ def add_argparse(self, parser: RichArgumentParser) -> Action:
def __str__(self) -> str:
return f"{self.short}/{self.long}"

@classmethod
def field(cls, *args, **kwargs) -> "Flag":
return field(default_factory=lambda: cls(*args, **kwargs))


@dataclass
class CliFlags:
config: Flag = Flag(
config: Flag = Flag.field(
name="config",
help=(
"either a path to a YAML file containing a configuration, or "
Expand All @@ -107,22 +114,22 @@ class CliFlags:
action="append",
metavar="/path/to/config.yaml",
)
options: Flag = Flag(
options: Flag = Flag.field(
name="options",
help="print all default options and CLI flags.",
action="store_true",
)
inputs: Flag = Flag(
inputs: Flag = Flag.field(
name="inputs",
help="print the input configuration.",
action="store_true",
)
parsed: Flag = Flag(
parsed: Flag = Flag.field(
name="parsed",
help="print the parsed configuration.",
action="store_true",
)
log_level: Flag = Flag(
log_level: Flag = Flag.field(
name="log-level",
help=(
"logging level to use for this program; can be one of "
Expand All @@ -131,30 +138,30 @@ class CliFlags:
default="WARNING",
choices=["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"],
)
debug: Flag = Flag(
debug: Flag = Flag.field(
name="debug",
help="enable debug mode; equivalent to '--log-level DEBUG'",
action="store_true",
)
quiet: Flag = Flag(
quiet: Flag = Flag.field(
name="quiet",
help="if provided, it does not print the configuration when running",
action="store_true",
)
resolvers: Flag = Flag(
resolvers: Flag = Flag.field(
name="resolvers",
help=(
"print all registered resolvers in OmegaConf, "
"Springs, and current codebase"
),
action="store_true",
)
nicknames: Flag = Flag(
nicknames: Flag = Flag.field(
name="nicknames",
help="print all registered nicknames in Springs",
action="store_true",
)
save: Flag = Flag(
save: Flag = Flag.field(
name="save",
help="save the configuration to a YAML file and exit",
default=None,
Expand Down Expand Up @@ -430,10 +437,8 @@ def wrap_main_method(
def cli(
config_node_cls: Optional[Type[CT]] = None,
) -> Callable[
[
# this is a main method that takes as first input a parsed config
Callable[Concatenate[CT, MP], RT]
],
# this is a main method that takes as first input a parsed config
[Callable[Concatenate[CT, MP], RT]],
# the decorated method doesn't expect the parsed config as first input,
# since that will be parsed from the command line
Callable[MP, RT],
Expand Down Expand Up @@ -487,6 +492,7 @@ def main(cfg: Config):
name = config_node_cls.__name__

def wrapper(func: Callable[Concatenate[CT, MP], RT]) -> Callable[MP, RT]:
@functools.wraps(func)
def wrapping(*args: MP.args, **kwargs: MP.kwargs) -> RT:
# I could have used a functools.partial here, but defining
# my own function instead allows me to provide nice typing
Expand All @@ -501,4 +507,8 @@ def wrapping(*args: MP.args, **kwargs: MP.kwargs) -> RT:

return wrapping

return wrapper
# TODO: figure out why mypy complains with the following error:
# Incompatible return value type (got "Callable[[Arg(Callable[[CT,
# **MP], RT], 'func')], Callable[MP, RT]]", expected
# "Callable[[Callable[[CT, **MP], RT]], Callable[MP, RT]]")
return wrapper # type: ignore
22 changes: 12 additions & 10 deletions src/springs/flexyclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,18 @@

from .utils import get_annotations

C = TypeVar("C", bound=Any)
_C = TypeVar("_C", bound=Any)


class FlexyClass(dict, Generic[C]):
class FlexyClass(dict, Generic[_C]):
"""A FlexyClass is a dictionary with some default values assigned to it
FlexyClasses are generally not used directly, but rather creating using
the `flexyclass` decorator.
NOTE: When instantiating a new FlexyClass object directly, the constructor
actually returns a `dataclasses.Field` object. This is for API consistency
with how dataclasses are used in a structured configuration. If you want to
access values in the FlexyClass directly, use FlexyClass.defaults property.
actually returns a `dict` object. This is for API consistency with how
dataclasses are used in a structured configuration. If you want to access
values in the FlexyClass directly, use FlexyClass.defaults property.
"""

__origin__: type = dict
Expand Down Expand Up @@ -60,7 +60,8 @@ def __new__(cls, **kwargs):
# to use flexyclasses in the same way they would use a dataclass.
factory_dict: Dict[str, Any] = {}
factory_dict = {**cls.defaults(), **kwargs}
return field(default_factory=lambda: factory_dict)
return factory_dict
# return field(default_factory=lambda: factory_dict)

@classmethod
def to_dict_config(cls, **kwargs: Any) -> DictConfig:
Expand All @@ -70,7 +71,7 @@ def to_dict_config(cls, **kwargs: Any) -> DictConfig:
return from_dict({**cls.defaults(), **kwargs})

@classmethod
def flexyclass(cls, target_cls: Type[C]) -> Type["FlexyClass"]:
def flexyclass(cls, target_cls: Type[_C]) -> Type["FlexyClass[_C]"]:
"""Decorator to create a FlexyClass from a class"""

if is_dataclass(target_cls):
Expand All @@ -86,15 +87,16 @@ def flexyclass(cls, target_cls: Type[C]) -> Type["FlexyClass"]:
for f_name, f_value in attributes_iterator
}

return type(
rt = type(
target_cls.__name__,
(FlexyClass,),
{"__flexyclass_defaults__": defaults},
)
return rt


@dataclass_transform()
def flexyclass(cls: Type[C]) -> Type[FlexyClass[C]]:
@dataclass_transform(field_specifiers=(Field, field))
def flexyclass(cls: Type[_C]) -> Type[FlexyClass[_C]]:
"""Alias for FlexyClass.flexyclass"""
return FlexyClass.flexyclass(cls)

Expand Down
14 changes: 12 additions & 2 deletions src/springs/rich_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,17 @@
import re
from argparse import SUPPRESS, ArgumentParser
from dataclasses import dataclass
from typing import IO, Any, Dict, Generator, List, Optional, Sequence, Union
from typing import (
IO,
Any,
Dict,
Generator,
List,
Optional,
Sequence,
Tuple,
Union,
)

from omegaconf import DictConfig, ListConfig
from rich import box
Expand Down Expand Up @@ -153,7 +163,7 @@ def format_usage(self):
for ag in self._action_groups:
for act in ag._group_actions:
if isinstance(act.metavar, str):
metavar = (act.metavar,)
metavar: Tuple[str, ...] = (act.metavar,)
elif act.metavar is None:
metavar = (act.dest.upper(),)
else:
Expand Down
30 changes: 4 additions & 26 deletions src/springs/shortcuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,37 +73,14 @@ def make_flexy(cls_: Any) -> Any:
return flexyclass(cls_)


def fval(value: T, **kwargs) -> T:
"""Shortcut for creating a Field with a default value.
Args:
value: value returned by default factory"""

return field(default=value, **kwargs)


def fobj(object: T, **kwargs) -> T:
"""Shortcut for creating a Field with a default_factory that returns
a specific object.
Args:
obj: object returned by default factory"""

def _factory_fn() -> T:
# make a copy so that the same object isn't returned
# (it's a factory, not a singleton!)
return copy.deepcopy(object)

return field(default_factory=_factory_fn, **kwargs)


def fdict(**kwargs: Any) -> Dict[str, Any]:
"""Shortcut for creating a Field with a default_factory that returns
a dictionary.
Args:
**kwargs: values for the dictionary returned by default factory"""
return fobj(kwargs)
kwargs = copy.deepcopy(kwargs)
return field(default_factory=lambda: kwargs)


def flist(*args: Any) -> List[Any]:
Expand All @@ -112,7 +89,8 @@ def flist(*args: Any) -> List[Any]:
Args:
*args: values for the list returned by default factory"""
return fobj(list(args))
l_args = list(copy.deepcopy(args))
return field(default_factory=lambda: l_args)


def debug_logger(*args: Any, **kwargs: Any) -> Logger:
Expand Down
Loading

0 comments on commit 24e58dd

Please sign in to comment.