Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade to Python 3.11 #22

Merged
merged 6 commits into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "springs"
version = "1.12.3"
version = "1.13.0"
description = """\
A set of utilities to create and manage typed configuration files \
effectively, built on top of OmegaConf.\
Expand Down
5 changes: 0 additions & 5 deletions src/springs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@
debug_logger,
fdict,
flist,
fobj,
fval,
get_nickname,
make_flexy,
make_target,
Expand All @@ -49,7 +47,6 @@
__version__ = get_version()

__all__ = [
"add_help",
"all_resolvers",
"cast",
"cli",
Expand All @@ -63,8 +60,6 @@
"field",
"flexyclass",
"flist",
"fobj",
"fval",
"from_dataclass",
"from_dict",
"from_file",
Expand Down
40 changes: 25 additions & 15 deletions src/springs/commandline.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import re
import sys
from argparse import Action
Expand Down Expand Up @@ -29,6 +30,7 @@
to_yaml,
unsafe_merge,
)
from .field_utils import field
from .flexyclasses import is_flexyclass
from .logging import configure_logging
from .nicknames import NicknameRegistry
Expand All @@ -42,6 +44,7 @@

# parameters for the main function
MP = ParamSpec("MP")
NP = ParamSpec("NP")

# type for the configuration
CT = TypeVar("CT")
Expand Down Expand Up @@ -92,10 +95,14 @@ def add_argparse(self, parser: RichArgumentParser) -> Action:
def __str__(self) -> str:
return f"{self.short}/{self.long}"

@classmethod
def field(cls, *args, **kwargs) -> "Flag":
return field(default_factory=lambda: cls(*args, **kwargs))


@dataclass
class CliFlags:
config: Flag = Flag(
config: Flag = Flag.field(
name="config",
help=(
"either a path to a YAML file containing a configuration, or "
Expand All @@ -107,22 +114,22 @@ class CliFlags:
action="append",
metavar="/path/to/config.yaml",
)
options: Flag = Flag(
options: Flag = Flag.field(
name="options",
help="print all default options and CLI flags.",
action="store_true",
)
inputs: Flag = Flag(
inputs: Flag = Flag.field(
name="inputs",
help="print the input configuration.",
action="store_true",
)
parsed: Flag = Flag(
parsed: Flag = Flag.field(
name="parsed",
help="print the parsed configuration.",
action="store_true",
)
log_level: Flag = Flag(
log_level: Flag = Flag.field(
name="log-level",
help=(
"logging level to use for this program; can be one of "
Expand All @@ -131,30 +138,30 @@ class CliFlags:
default="WARNING",
choices=["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"],
)
debug: Flag = Flag(
debug: Flag = Flag.field(
name="debug",
help="enable debug mode; equivalent to '--log-level DEBUG'",
action="store_true",
)
quiet: Flag = Flag(
quiet: Flag = Flag.field(
name="quiet",
help="if provided, it does not print the configuration when running",
action="store_true",
)
resolvers: Flag = Flag(
resolvers: Flag = Flag.field(
name="resolvers",
help=(
"print all registered resolvers in OmegaConf, "
"Springs, and current codebase"
),
action="store_true",
)
nicknames: Flag = Flag(
nicknames: Flag = Flag.field(
name="nicknames",
help="print all registered nicknames in Springs",
action="store_true",
)
save: Flag = Flag(
save: Flag = Flag.field(
name="save",
help="save the configuration to a YAML file and exit",
default=None,
Expand Down Expand Up @@ -430,10 +437,8 @@ def wrap_main_method(
def cli(
config_node_cls: Optional[Type[CT]] = None,
) -> Callable[
[
# this is a main method that takes as first input a parsed config
Callable[Concatenate[CT, MP], RT]
],
# this is a main method that takes as first input a parsed config
[Callable[Concatenate[CT, MP], RT]],
# the decorated method doesn't expect the parsed config as first input,
# since that will be parsed from the command line
Callable[MP, RT],
Expand Down Expand Up @@ -487,6 +492,7 @@ def main(cfg: Config):
name = config_node_cls.__name__

def wrapper(func: Callable[Concatenate[CT, MP], RT]) -> Callable[MP, RT]:
@functools.wraps(func)
def wrapping(*args: MP.args, **kwargs: MP.kwargs) -> RT:
# I could have used a functools.partial here, but defining
# my own function instead allows me to provide nice typing
Expand All @@ -501,4 +507,8 @@ def wrapping(*args: MP.args, **kwargs: MP.kwargs) -> RT:

return wrapping

return wrapper
# TODO: figure out why mypy complains with the following error:
# Incompatible return value type (got "Callable[[Arg(Callable[[CT,
# **MP], RT], 'func')], Callable[MP, RT]]", expected
# "Callable[[Callable[[CT, **MP], RT]], Callable[MP, RT]]")
return wrapper # type: ignore
22 changes: 12 additions & 10 deletions src/springs/flexyclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,18 @@

from .utils import get_annotations

C = TypeVar("C", bound=Any)
_C = TypeVar("_C", bound=Any)


class FlexyClass(dict, Generic[C]):
class FlexyClass(dict, Generic[_C]):
"""A FlexyClass is a dictionary with some default values assigned to it
FlexyClasses are generally not used directly, but rather creating using
the `flexyclass` decorator.

NOTE: When instantiating a new FlexyClass object directly, the constructor
actually returns a `dataclasses.Field` object. This is for API consistency
with how dataclasses are used in a structured configuration. If you want to
access values in the FlexyClass directly, use FlexyClass.defaults property.
actually returns a `dict` object. This is for API consistency with how
dataclasses are used in a structured configuration. If you want to access
values in the FlexyClass directly, use FlexyClass.defaults property.
"""

__origin__: type = dict
Expand Down Expand Up @@ -60,7 +60,8 @@ def __new__(cls, **kwargs):
# to use flexyclasses in the same way they would use a dataclass.
factory_dict: Dict[str, Any] = {}
factory_dict = {**cls.defaults(), **kwargs}
return field(default_factory=lambda: factory_dict)
return factory_dict
# return field(default_factory=lambda: factory_dict)

@classmethod
def to_dict_config(cls, **kwargs: Any) -> DictConfig:
Expand All @@ -70,7 +71,7 @@ def to_dict_config(cls, **kwargs: Any) -> DictConfig:
return from_dict({**cls.defaults(), **kwargs})

@classmethod
def flexyclass(cls, target_cls: Type[C]) -> Type["FlexyClass"]:
def flexyclass(cls, target_cls: Type[_C]) -> Type["FlexyClass[_C]"]:
"""Decorator to create a FlexyClass from a class"""

if is_dataclass(target_cls):
Expand All @@ -86,15 +87,16 @@ def flexyclass(cls, target_cls: Type[C]) -> Type["FlexyClass"]:
for f_name, f_value in attributes_iterator
}

return type(
rt = type(
target_cls.__name__,
(FlexyClass,),
{"__flexyclass_defaults__": defaults},
)
return rt


@dataclass_transform()
def flexyclass(cls: Type[C]) -> Type[FlexyClass[C]]:
@dataclass_transform(field_specifiers=(Field, field))
def flexyclass(cls: Type[_C]) -> Type[FlexyClass[_C]]:
"""Alias for FlexyClass.flexyclass"""
return FlexyClass.flexyclass(cls)

Expand Down
14 changes: 12 additions & 2 deletions src/springs/rich_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,17 @@
import re
from argparse import SUPPRESS, ArgumentParser
from dataclasses import dataclass
from typing import IO, Any, Dict, Generator, List, Optional, Sequence, Union
from typing import (
IO,
Any,
Dict,
Generator,
List,
Optional,
Sequence,
Tuple,
Union,
)

from omegaconf import DictConfig, ListConfig
from rich import box
Expand Down Expand Up @@ -153,7 +163,7 @@ def format_usage(self):
for ag in self._action_groups:
for act in ag._group_actions:
if isinstance(act.metavar, str):
metavar = (act.metavar,)
metavar: Tuple[str, ...] = (act.metavar,)
elif act.metavar is None:
metavar = (act.dest.upper(),)
else:
Expand Down
30 changes: 4 additions & 26 deletions src/springs/shortcuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,37 +73,14 @@ def make_flexy(cls_: Any) -> Any:
return flexyclass(cls_)


def fval(value: T, **kwargs) -> T:
"""Shortcut for creating a Field with a default value.

Args:
value: value returned by default factory"""

return field(default=value, **kwargs)


def fobj(object: T, **kwargs) -> T:
"""Shortcut for creating a Field with a default_factory that returns
a specific object.

Args:
obj: object returned by default factory"""

def _factory_fn() -> T:
# make a copy so that the same object isn't returned
# (it's a factory, not a singleton!)
return copy.deepcopy(object)

return field(default_factory=_factory_fn, **kwargs)


def fdict(**kwargs: Any) -> Dict[str, Any]:
"""Shortcut for creating a Field with a default_factory that returns
a dictionary.

Args:
**kwargs: values for the dictionary returned by default factory"""
return fobj(kwargs)
kwargs = copy.deepcopy(kwargs)
return field(default_factory=lambda: kwargs)


def flist(*args: Any) -> List[Any]:
Expand All @@ -112,7 +89,8 @@ def flist(*args: Any) -> List[Any]:

Args:
*args: values for the list returned by default factory"""
return fobj(list(args))
l_args = list(copy.deepcopy(args))
return field(default_factory=lambda: l_args)


def debug_logger(*args: Any, **kwargs: Any) -> Logger:
Expand Down
Loading