Skip to content

Commit

Permalink
fix #143, some work on #140
Browse files Browse the repository at this point in the history
  • Loading branch information
bckohan committed Nov 19, 2024
1 parent d1e2d5f commit 9c85ba8
Show file tree
Hide file tree
Showing 11 changed files with 445 additions and 27 deletions.
171 changes: 148 additions & 23 deletions django_typer/management/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@
from ..utils import ( # noqa: E402
_command_context,
_load_command_plugins,
accepted_kwargs,
called_from_command_definition,
called_from_module,
get_current_command,
get_usage_script,
is_method,
with_typehint,
Expand Down Expand Up @@ -323,13 +325,35 @@ def _common_options(
_common_params: t.Sequence[t.Union[click.Argument, click.Option]] = []


def _get_common_params() -> t.Sequence[t.Union[click.Argument, click.Option]]:
def _normalize_suppressed_arguments(
command: t.Union[t.Type["TyperCommand"], "TyperCommand"],
) -> t.Set[str]:
if command.suppressed_base_arguments:
return set(
[
arg.lstrip("--").replace("-", "_")
for arg in command.suppressed_base_arguments
]
)
return set()


def _get_common_params(
command: t.Type["TyperCommand"],
) -> t.Sequence[t.Union[click.Argument, click.Option]]:
"""Use typer to convert the common options to click options"""
global _common_params
if not _common_params:
_common_params = get_params_convertors_ctx_param_name_from_function(
_common_options
)[0]
suppressed = _normalize_suppressed_arguments(command)
if suppressed:
return [
param
for param in _common_params
if param.name and param.name not in suppressed
]
return _common_params


Expand All @@ -339,6 +363,13 @@ def _get_common_params() -> t.Sequence[t.Union[click.Argument, click.Option]]:
}


def _remove_suppressed(
command: "TyperCommand", params: t.Dict[str, t.Any], manual: t.Set[str] = set()
) -> t.Dict[str, t.Any]:
suppressed = _normalize_suppressed_arguments(command)
return {k: v for k, v in params.items() if k not in suppressed or k in manual}


class _ParsedArgs(SimpleNamespace):
"""
Emulate the argparse.Namespace class so that we can pass the parsed arguments
Expand All @@ -349,9 +380,13 @@ def __init__(self, args: t.Sequence[t.Any], **kwargs: t.Any):
super().__init__(**kwargs)
self.args = args
self.traceback = kwargs.get("traceback", TyperCommand._traceback)
if not hasattr(self, "pythonpath"):
self.pythonpath = None
if not hasattr(self, "settings"):
self.settings = None

def _get_kwargs(self):
return {**COMMON_DEFAULTS, **vars(self)}
return vars(self)


class Context(TyperContext):
Expand Down Expand Up @@ -404,14 +439,30 @@ def __init__(
**kwargs: t.Any,
):
super().__init__(command, parent=parent, **kwargs)
if supplied_params:
self._supplied_params = supplied_params
if django_command:
self.django_command = django_command
if supplied_params:
# if we're at the top level, default django parameters that
# were suppressed may have been injected into execute() and
# wound up here. We remove them to keep the interface honest
supplied_params = _remove_suppressed(
self.django_command,
supplied_params,
{
param.name
for param in get_typer_command(
self.django_command.typer_app
).params
if param.name
},
)
else:
assert parent
self.django_command = parent.django_command

if supplied_params:
self._supplied_params = supplied_params

self.params = self.ParamDict(
{**self.params, **self.supplied_params},
supplied=list(self.supplied_params.keys()),
Expand All @@ -428,7 +479,7 @@ class DjangoTyperMixin(with_typehint(CoreTyperGroup)): # type: ignore[misc]
"""

context_class: t.Type[click.Context] = Context
django_command: "TyperCommand"
django_command: t.Type["TyperCommand"]
_callback: t.Optional[t.Callable[..., t.Any]] = None
_callback_is_method: t.Optional[bool] = None
common_init: bool = False
Expand Down Expand Up @@ -502,17 +553,8 @@ def common_params(self) -> t.Sequence[t.Union[click.Argument, click.Option]]:
Add the common parameters to this group only if this group is the root
command's user specified initialize callback.
"""
suppressed = getattr(self.django_command, "suppressed_base_arguments", None)
return (
[
param
for param in _get_common_params()
if param.name
and param.name
not in (
{arg.lstrip("--").replace("-", "_") for arg in suppressed or []}
)
]
_get_common_params(self.django_command)
if self.common_init or self.no_callback
else []
)
Expand Down Expand Up @@ -637,6 +679,37 @@ def list_commands(self, ctx: click.Context) -> t.List[str]:
# hence the following mishegoss


class Finalizer(t.Generic[P, R]):
"""
A class that wraps a finalizer function and makes it callable while passing the
django command instance if expected.
"""

finalizer: t.Callable[P, R]
is_method: bool

def __init__(self, finalizer: t.Callable[P, R]):
self.finalizer = finalizer
self.is_method = bool(is_method(finalizer))

def __call__(
self,
*args: P.args,
**kwargs: P.kwargs,
) -> R:
if self.is_method:
cmd = kwargs.pop("_command", None) or (
getattr(
t.cast(Context, click.get_current_context(silent=True)),
"django_command",
None,
)
or get_current_command()
)
args = [cmd, *args] # type: ignore
return self.finalizer(*args, **accepted_kwargs(self.finalizer, kwargs))


@t.overload # pragma: no cover
def _check_static(
func: typer.models.CommandFunctionType,
Expand Down Expand Up @@ -682,7 +755,7 @@ def _cache_initializer(
**kwargs: t.Any,
):
def register(
cmd: "TyperCommand",
cmd: t.Type["TyperCommand"],
_name: t.Optional[str] = Default(None),
_help: t.Optional[t.Union[str, Promise]] = Default(None),
**extra,
Expand All @@ -702,6 +775,14 @@ def register(
setattr(callback, _CACHE_KEY, register)


def _cache_finalizer(callback: t.Callable[..., t.Any]):
def register(cmd: t.Type["TyperCommand"]):
finalizer = Finalizer(_strip_static(callback))
cmd.typer_app.info.result_callback = finalizer

setattr(callback, _CACHE_KEY, register)


def _cache_command(
callback: t.Callable[..., t.Any],
name: t.Optional[str] = None,
Expand All @@ -710,7 +791,7 @@ def _cache_command(
**kwargs: t.Any,
):
def register(
cmd: "TyperCommand",
cmd: t.Type["TyperCommand"],
_name: t.Optional[str] = None,
_help: t.Optional[t.Union[str, Promise]] = None,
**extra,
Expand Down Expand Up @@ -1319,9 +1400,6 @@ def __getattr__(self, name: str) -> t.Any:
)
)

# def __repr__(self) -> str:
# return f'<{self.__class__.__module__}.{self.__class__.__name__} for {repr(self.proxied)}>'


def initialize(
name: t.Optional[str] = Default(None),
Expand Down Expand Up @@ -1468,6 +1546,19 @@ def make_initializer(func: t.Callable[P2, R2]) -> t.Callable[P2, R2]:
callback = initialize # allow callback as an alias


def finalize() -> t.Callable[[t.Callable[P2, R2]], t.Callable[P2, R2]]:
"""
TODO
"""

def make_finalizer(func: t.Callable[P2, R2]) -> t.Callable[P2, R2]:
func = _check_static(func)
_cache_finalizer(func)
return func

return make_finalizer


def command(
name: t.Optional[str] = None,
*,
Expand Down Expand Up @@ -1878,6 +1969,7 @@ class TyperCommandMeta(type):
typer_app: Typer
no_color: bool
force_color: bool
skip_checks: bool
is_compound_command: bool
_handle: t.Optional[t.Callable[..., t.Any]]

Expand Down Expand Up @@ -2332,7 +2424,14 @@ def parse_args(self, args=None, namespace=None) -> _ParsedArgs:
django_command=self.django_command,
args=list(args or []),
) as ctx:
common = {**COMMON_DEFAULTS, **ctx.params}
common = {
**_remove_suppressed(
self.django_command,
COMMON_DEFAULTS,
{param.name for param in cmd.params if param.name},
),
**ctx.params,
}
self.django_command._traceback = common.get(
"traceback", self.django_command._traceback
)
Expand Down Expand Up @@ -2500,10 +2599,13 @@ def command2(self, option: t.Optional[str] = None):
typer_app: Typer
no_color: bool = False
force_color: bool = False
skip_checks: bool = False
_handle: t.Callable[..., t.Any]
_traceback: bool = False
_help_kwarg: t.Optional[str] = Default(None)
_defined_groups: t.Dict[str, Typer] = {}
_finalizer: t.Optional[Finalizer] = None
_suppressed_base_arguments: t.Optional[t.Set[str]] = None

help: t.Optional[t.Union[DefaultPlaceholder, str, Promise]] = Default(None) # type: ignore

Expand Down Expand Up @@ -3065,14 +3167,23 @@ def _run(self, *args, **options):
:return: t.Any object returned by the Typer app
"""
with self:
return self.typer_app(
result = self.typer_app(
args=args,
standalone_mode=False,
supplied_params=options,
django_command=self,
complete_var=None,
prog_name=f"{sys.argv[0]} {self.typer_app.info.name}",
)
if not self.is_compound_command and isinstance(
self.typer_app.info.result_callback, Finalizer
):
# result callbacks are not called on singular commands by click/typer
# we do that here to keep our interface consistent
return self.typer_app.info.result_callback(
result, **options, _command=self
)
return result

def run_from_argv(self, argv):
"""
Expand Down Expand Up @@ -3102,16 +3213,30 @@ def execute(self, *args, **options):
"""
no_color = self.no_color
force_color = self.force_color
skip_checks = self.skip_checks
if options.get("no_color", None) is not None:
self.no_color = options["no_color"]
if options.get("force_color", None) is not None:
self.force_color = options["force_color"]
if options.get("skip_checks", None) is not None:
self.skip_checks = options["skip_checks"]
try:
with self:
return super().execute(*args, **options)
# base class requires force_color, no_color and skip_checks to be present - we
# allow them to be suppressed
return super().execute(
*args,
**{
"force_color": self.force_color,
"no_color": self.no_color,
"skip_checks": self.skip_checks,
**options,
},
)
finally:
self.no_color = no_color
self.force_color = force_color
self.skip_checks = skip_checks

def echo(
self, message: t.Optional[t.Any] = None, nl: bool = True, err: bool = False
Expand Down
21 changes: 21 additions & 0 deletions django_typer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,24 @@ def is_method(
return params[0] == "self"
return isinstance(func_or_params, MethodType)
return None


def accepts_var_kwargs(func: t.Callable[..., t.Any]) -> bool:
"""
Determines if the given function accepts variable keyword arguments.
"""
for param in reversed(inspect.signature(func).parameters.values()):
return param.kind is inspect.Parameter.VAR_KEYWORD
return False


def accepted_kwargs(
func: t.Callable[..., t.Any], kwargs: t.Dict[str, t.Any]
) -> t.Dict[str, t.Any]:
"""
Return the named keyword arguments that are accepted by the given function.
"""
if accepts_var_kwargs(func):
return kwargs
param_names = set(inspect.signature(func).parameters.keys())
return {k: v for k, v in kwargs.items() if k in param_names}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from django_typer.management import TyperCommand, COMMON_DEFAULTS
from typer.models import Context as TyperContext


class Command(TyperCommand):
help = "Test that django parameter suppression works as expected"

suppressed_base_arguments = set(COMMON_DEFAULTS.keys())

def handle(self, ctx: TyperContext):
assert self.__class__ is Command
assert isinstance(ctx, TyperContext)
assert not ctx.params
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from django_typer.management import TyperCommand, COMMON_DEFAULTS, initialize, command
from typer.models import Context as TyperContext


class Command(TyperCommand):
help = "Test that django parameter suppression works as expected"

suppressed_base_arguments = set(COMMON_DEFAULTS.keys())

@initialize()
def init(self, ctx: TyperContext):
assert self.__class__ is Command
assert isinstance(ctx, TyperContext)
assert not ctx.params

@command(name="cmd")
def handle(self, ctx: TyperContext):
assert self.__class__ is Command
assert isinstance(ctx, TyperContext)
assert not ctx.params
Loading

0 comments on commit 9c85ba8

Please sign in to comment.